diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 54e8a6d93..2cad504f3 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -140,7 +140,7 @@ jobs: - name: Install Colossal-AI run: | - CUDA_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install -v -e . pip install -r requirements/requirements-test.txt - name: Store Colossal-AI Cache @@ -160,9 +160,7 @@ jobs: --ignore tests/test_gptq \ --ignore tests/test_infer_ops \ --ignore tests/test_legacy \ - --ignore tests/test_moe \ --ignore tests/test_smoothquant \ - --ignore tests/test_checkpoint_io \ tests/ env: LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index 5b0103eb7..ae1a5275e 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -12,7 +12,7 @@ jobs: if: github.repository == 'hpcaitech/ColossalAI' runs-on: [self-hosted, gpu] container: - image: hpcaitech/pytorch-cuda:2.0.0-11.7.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny timeout-minutes: 90 steps: @@ -23,6 +23,7 @@ jobs: ngpu=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) endIndex=$(($ngpu-1)) for i in $(seq 0 $endIndex); + do gpu_used=$(nvidia-smi -i $i --query-gpu=memory.used --format=csv,noheader,nounits) [ "$gpu_used" -gt "2000" ] && avai=false done @@ -54,7 +55,7 @@ jobs: if: steps.check-avai.outputs.avai == 'true' run: | [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ - CUDA_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install -v -e . cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ pip install -r requirements/requirements-test.txt diff --git a/.github/workflows/example_check_on_dispatch.yml b/.github/workflows/example_check_on_dispatch.yml index 02e30f52a..bba321fd2 100644 --- a/.github/workflows/example_check_on_dispatch.yml +++ b/.github/workflows/example_check_on_dispatch.yml @@ -45,9 +45,9 @@ jobs: fail-fast: false matrix: ${{fromJson(needs.manual_check_matrix_preparation.outputs.matrix)}} container: - image: hpcaitech/pytorch-cuda:2.0.0-11.7.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm -v /data/scratch/examples-data:/data/ - timeout-minutes: 10 + timeout-minutes: 15 steps: - name: 📚 Checkout uses: actions/checkout@v3 diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml index 6d6952aa1..fcff8e569 100644 --- a/.github/workflows/example_check_on_pr.yml +++ b/.github/workflows/example_check_on_pr.yml @@ -77,7 +77,7 @@ jobs: fail-fast: false matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}} container: - image: hpcaitech/pytorch-cuda:2.0.0-11.7.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm -v /data/scratch/examples-data:/data/ timeout-minutes: 20 concurrency: diff --git a/.github/workflows/example_check_on_schedule.yml b/.github/workflows/example_check_on_schedule.yml index 919fa5092..abb947949 100644 --- a/.github/workflows/example_check_on_schedule.yml +++ b/.github/workflows/example_check_on_schedule.yml @@ -34,7 +34,7 @@ jobs: fail-fast: false matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: - image: hpcaitech/pytorch-cuda:2.0.0-11.7.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 timeout-minutes: 10 steps: - name: 📚 Checkout diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml index f9e9f4009..bb0ceb4a8 100644 --- a/.github/workflows/run_chatgpt_examples.yml +++ b/.github/workflows/run_chatgpt_examples.yml @@ -18,7 +18,7 @@ jobs: github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' runs-on: [self-hosted, gpu] container: - image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm -v /data/scratch/github_actions/chat:/data/scratch/github_actions/chat --shm-size=10.24gb timeout-minutes: 30 defaults: diff --git a/.github/workflows/run_chatgpt_unit_tests.yml b/.github/workflows/run_chatgpt_unit_tests.yml index ec5c8ffa3..7986889e0 100644 --- a/.github/workflows/run_chatgpt_unit_tests.yml +++ b/.github/workflows/run_chatgpt_unit_tests.yml @@ -20,7 +20,7 @@ jobs: github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' runs-on: [self-hosted, gpu] container: - image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm -v /data/scratch/chatgpt:/data/scratch/chatgpt timeout-minutes: 30 defaults: diff --git a/.github/workflows/run_colossalqa_unit_tests.yml b/.github/workflows/run_colossalqa_unit_tests.yml index 763db2772..00944b92d 100644 --- a/.github/workflows/run_colossalqa_unit_tests.yml +++ b/.github/workflows/run_colossalqa_unit_tests.yml @@ -19,7 +19,7 @@ jobs: github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' runs-on: [self-hosted, gpu] container: - image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 volumes: - /data/scratch/test_data_colossalqa:/data/scratch/test_data_colossalqa - /data/scratch/llama-tiny:/data/scratch/llama-tiny @@ -51,4 +51,4 @@ jobs: TEST_DATA_PATH_EN: /data/scratch/test_data_colossalqa/companies.txt TEST_DATA_PATH_ZH: /data/scratch/test_data_colossalqa/companies_zh.txt TEST_DOCUMENT_LOADER_DATA_PATH: /data/scratch/test_data_colossalqa/tests/* - SQL_FILE_PATH: /data/scratch/test_data_colossalqa/sql_file_path \ No newline at end of file + SQL_FILE_PATH: /data/scratch/test_data_colossalqa/sql_file_path diff --git a/MANIFEST.in b/MANIFEST.in index ad26b634a..f0a5611ef 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,4 @@ include *.txt README.md recursive-include requirements *.txt recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi -recursive-include op_builder *.py +recursive-include extensions *.py *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi diff --git a/README.md b/README.md index 971f4375a..13757eece 100644 --- a/README.md +++ b/README.md @@ -141,25 +141,26 @@ distributed training and inference in a few lines. [[HuggingFace model weights]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-13b-base) [[Modelscope model weights]](https://www.modelscope.cn/models/colossalai/Colossal-LLaMA-2-13b-base/summary) -| Model | Backbone | Tokens Consumed | MMLU (5-shot) | CMMLU (5-shot)| AGIEval (5-shot) | GAOKAO (0-shot) | CEval (5-shot) | -| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-------------: | :-------------: | -| Baichuan-7B | - | 1.2T | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 | -| Baichuan-13B-Base | - | 1.4T | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 | -| Baichuan2-7B-Base | - | 2.6T | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 | -| Baichuan2-13B-Base | - | 2.6T | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 | -| ChatGLM-6B | - | 1.0T | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 | -| ChatGLM2-6B | - | 1.4T | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 | -| InternLM-7B | - | 1.6T | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 | -| Qwen-7B | - | 2.2T | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 | -| Llama-2-7B | - | 2.0T | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - | -| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | 37.43 | 29.92 | 32.00 | 27.57 | - | -| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | 38.56 | 31.52 | 30.99 | 25.95 | - | -| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 | -| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | 43.73 | 42.04 | 37.64 | 30.61 | - | -| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | 48.41 | 38.31 | 38.45 | 27.72 | - | -| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | 49.96 | 41.10 | 39.83 | 33.00 | - | -| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | 50.25 | 40.99 | 40.04 | 30.54 | - | -| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 | +| Model | Backbone | Tokens Consumed | MMLU (5-shot) | CMMLU (5-shot)| AGIEval (5-shot) | GAOKAO (0-shot) | CEval (5-shot) | +| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-------------: | :-------------: | +| Baichuan-7B | - | 1.2T | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 | +| Baichuan-13B-Base | - | 1.4T | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 | +| Baichuan2-7B-Base | - | 2.6T | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 | +| Baichuan2-13B-Base | - | 2.6T | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 | +| ChatGLM-6B | - | 1.0T | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 | +| ChatGLM2-6B | - | 1.4T | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 | +| InternLM-7B | - | 1.6T | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 | +| Qwen-7B | - | 2.2T | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 | +| Llama-2-7B | - | 2.0T | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - | +| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | 37.43 | 29.92 | 32.00 | 27.57 | - | +| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | 38.56 | 31.52 | 30.99 | 25.95 | - | +| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 | +| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | 43.73 | 42.04 | 37.64 | 30.61 | - | +| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | 48.41 | 38.31 | 38.45 | 27.72 | - | +| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | 49.96 | 41.10 | 39.83 | 33.00 | - | +| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | 50.25 | 40.99 | 40.04 | 30.54 | - | +| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 | +| **Colossal-LLaMA-2-13b-base** | Llama-2-13B | **0.025T** | 56.42 | 61.80 | 54.69 | 69.53 | 60.3 | ### ColossalChat diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py index d69666898..330e4e0e3 100644 --- a/applications/Chat/coati/trainer/ppo.py +++ b/applications/Chat/coati/trainer/ppo.py @@ -10,7 +10,7 @@ from torch.utils.data import DataLoader, DistributedSampler from tqdm import tqdm from transformers import PreTrainedTokenizerBase -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from .base import OnPolicyTrainer from .callbacks import Callback @@ -105,7 +105,7 @@ class PPOTrainer(OnPolicyTrainer): self.critic_optim = critic_optim self.offload_inference_models = offload_inference_models - self.device = get_current_device() + self.device = get_accelerator().get_current_device() def _before_fit( self, diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py index 7129edb06..95f016786 100644 --- a/applications/Chat/coati/trainer/strategies/colossalai.py +++ b/applications/Chat/coati/trainer/strategies/colossalai.py @@ -6,7 +6,6 @@ import torch.nn as nn import colossalai from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel -from colossalai.utils import get_current_device from colossalai.zero.gemini.gemini_ddp import GeminiDDP from .ddp import DDPStrategy @@ -158,9 +157,19 @@ class GeminiStrategy(DDPStrategy): warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.") + # colossalai has changed api for get_current_device in 0.3.4 version or newer + try: + from colossalai.accelerator import get_accelerator + + chunk_init_device = get_accelerator().get_current_device() + except: + from colossalai.utils import get_current_device + + chunk_init_device = get_current_device() + # NOTE: dist should be initialized before calling get_current_device() plugin_initializer = lambda: GeminiPlugin( - chunk_init_device=get_current_device(), + chunk_init_device=chunk_init_device, placement_policy=placement_policy, shard_param_frac=shard_param_frac, offload_optim_frac=offload_optim_frac, diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py b/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py index 43297633d..439135503 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py @@ -6,12 +6,12 @@ Initialize new tokenizer for continual pre-training """ import argparse -import os import json +import os from typing import List, Union -from transformers.models.llama.tokenization_llama import LlamaTokenizer from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model +from transformers.models.llama.tokenization_llama import LlamaTokenizer from colossalai.logging import get_dist_logger diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py index 079faaace..9f6c9c1cc 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py @@ -16,7 +16,10 @@ import torch def unwrap(model): - return model.unwrap().module + if hasattr(model, "module"): + return unwrap_model(model.module) + else: + return model def neftune_post_forward_hook(module, input, output): diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py index 41b4ef031..92863e8e4 100644 --- a/applications/Colossal-LLaMA-2/train.py +++ b/applications/Colossal-LLaMA-2/train.py @@ -1,44 +1,37 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -Continual Pre-training of LLaMA-2 developed by Colossal-AI Team +Continual Pre-training of LLaMA-2 developed by Colossal-AI Team """ -import json import argparse +import json import os import resource from contextlib import nullcontext -from tqdm import tqdm import torch import torch.distributed as dist +from colossal_llama2.dataset.loader import ( + DataCollatorForSupervisedDataset, + StatefulDistributedSampler, + load_tokenized_dataset, + setup_distributed_dataloader, +) +from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint +from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention +from colossal_llama2.utils.froze import freeze_non_embeds_parameters from torch.utils.tensorboard import SummaryWriter -from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig +from tqdm import tqdm +from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import ( - GeminiPlugin, - LowLevelZeroPlugin, - HybridParallelPlugin, -) +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device - -from colossal_llama2.dataset.loader import ( - load_tokenized_dataset, - setup_distributed_dataloader, - DataCollatorForSupervisedDataset, - StatefulDistributedSampler, -) - -from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention -from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint -from colossal_llama2.utils.froze import freeze_non_embeds_parameters def get_model_numel(model: torch.nn.Module) -> int: @@ -215,9 +208,18 @@ def main() -> None: # ====================================================== # Initialize Model, Objective, Optimizer and LR Scheduler # ====================================================== - init_ctx = ( - LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() - ) + + # colossalai has changed api for get_current_device in 0.3.4 version or newer + try: + from colossalai.accelerator import get_accelerator + + current_device = get_accelerator().get_current_device() + except: + from colossalai.utils import get_current_device + + current_device = get_current_device() + + init_ctx = LazyInitContext(default_device=current_device) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() with init_ctx: model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained)) # Freeze part of parameters. @@ -320,7 +322,7 @@ def main() -> None: initial=start_step, ) as pbar: for step, batch in pbar: - batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)} + batch = {k: v.to(current_device) for k, v in batch.items() if isinstance(v, torch.Tensor)} batch_output = model(**batch) @@ -372,9 +374,7 @@ def main() -> None: # Final save. coordinator.print_on_master("Start saving final model checkpoint") booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) - coordinator.print_on_master( - f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}" - ) + coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}") coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") diff --git a/applications/ColossalQA/colossalqa/local/llm.py b/applications/ColossalQA/colossalqa/local/llm.py index ff7346adc..0aa383e9d 100644 --- a/applications/ColossalQA/colossalqa/local/llm.py +++ b/applications/ColossalQA/colossalqa/local/llm.py @@ -136,6 +136,19 @@ class ColossalLLM(LLM): """Get the identifying parameters.""" return {"n": self.n} + def get_token_ids(self, text: str) -> List[int]: + """Return the ordered ids of the tokens in a text. + + Args: + text: The string input to tokenize. + + Returns: + A list of ids corresponding to the tokens in the text, in order they occur + in the text. + """ + # use the colossal llm's tokenizer instead of langchain's cached GPT2 tokenizer + return self.api.tokenizer.encode(text) + class VllmLLM(LLM): """ diff --git a/colossalai/__init__.py b/colossalai/__init__.py index 7da555903..6b7f5d055 100644 --- a/colossalai/__init__.py +++ b/colossalai/__init__.py @@ -1,4 +1,5 @@ from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch +from . import accelerator try: # .version will be created by setup.py diff --git a/colossalai/accelerator/README.md b/colossalai/accelerator/README.md new file mode 100644 index 000000000..8c644493b --- /dev/null +++ b/colossalai/accelerator/README.md @@ -0,0 +1,20 @@ +# 🚀 Accelerator + +## 🔗 Table of Contents + +- [🚀 Accelerator](#-accelerator) + - [🔗 Table of Contents](#-table-of-contents) + - [📚 Introduction](#-introduction) + - [📌 Design and Acknowledgement](#-design-and-acknowledgement) + +## 📚 Introduction + +This module offers a layer of abstraction for ColossalAI. With this module, the user can easily switch between different accelerator backends, such as Nvidia GPUs, Huawei NPUs, etc. This module is an attempt to make users' code portable across different hardware platform with a simple `auto_set_accelerator()` API. + +## 📌 Design and Acknowledgement + +Our `accelerator` module is heavily inspired by [`deepspeed/accelerator`](https://www.deepspeed.ai/tutorials/accelerator-abstraction-interface/). We found that it is a very well-designed and well-structured module that can be easily integrated into our project. We would like to thank the DeepSpeed team for their great work. + +We implemented this accelerator module from scratch. At the same time, we have implemented our own modifications: +1. we updated the accelerator API names to be aligned with PyTorch's native API names. +2. we did not include the `op builder` in the `accelerator`. Instead, we have reconstructed our `kernel` module to automatically match the accelerator and its corresponding kernel implementations, so as to make modules less tangled. diff --git a/colossalai/accelerator/__init__.py b/colossalai/accelerator/__init__.py new file mode 100644 index 000000000..1405133af --- /dev/null +++ b/colossalai/accelerator/__init__.py @@ -0,0 +1,15 @@ +from .api import auto_set_accelerator, get_accelerator, set_accelerator +from .base_accelerator import BaseAccelerator +from .cpu_accelerator import CpuAccelerator +from .cuda_accelerator import CudaAccelerator +from .npu_accelerator import NpuAccelerator + +__all__ = [ + "get_accelerator", + "set_accelerator", + "auto_set_accelerator", + "BaseAccelerator", + "CudaAccelerator", + "NpuAccelerator", + "CpuAccelerator", +] diff --git a/colossalai/accelerator/api.py b/colossalai/accelerator/api.py new file mode 100644 index 000000000..02b3055d7 --- /dev/null +++ b/colossalai/accelerator/api.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python +from collections import OrderedDict +from typing import Union + +from .base_accelerator import BaseAccelerator +from .cpu_accelerator import CpuAccelerator +from .cuda_accelerator import CudaAccelerator +from .npu_accelerator import NpuAccelerator + +__all__ = ["set_accelerator", "auto_set_accelerator", "get_accelerator"] + + +_ACCELERATOR = None + + +# we use ordered dictionary here to associate the +# order with device check priority +# i.e. auto_set_accelerator will check cuda first +_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator, cpu=CpuAccelerator) + + +def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None: + """ + Set the global accelerator for the current process. + + Args: + accelerator (Union[str, BaseAccelerator]): the type of accelerator to which the current device belongs. + """ + + global _ACCELERATOR + + if isinstance(accelerator, str): + _ACCELERATOR = _ACCELERATOR_MAPPING[accelerator]() + elif isinstance(accelerator, BaseAccelerator): + _ACCELERATOR = accelerator + else: + raise TypeError("accelerator must be either a string or an instance of BaseAccelerator") + + +def auto_set_accelerator() -> None: + """ + Automatically check if any accelerator is available. + If an accelerator is availabe, set it as the global accelerator. + """ + global _ACCELERATOR + + for accelerator_name, accelerator_cls in _ACCELERATOR_MAPPING.items(): + try: + accelerator = accelerator_cls() + if accelerator_name == "cpu" or accelerator.is_available(): + _ACCELERATOR = accelerator + break + except: + pass + + if _ACCELERATOR is None: + raise RuntimeError("No accelerator is available.") + + +def get_accelerator() -> BaseAccelerator: + """ + Return the accelerator for the current process. If the accelerator is not initialized, it will be initialized + to the default accelerator type. + + Returns: the accelerator for the current process. + """ + global _ACCELERATOR + + if _ACCELERATOR is None: + auto_set_accelerator() + return _ACCELERATOR diff --git a/colossalai/accelerator/base_accelerator.py b/colossalai/accelerator/base_accelerator.py new file mode 100644 index 000000000..33c113999 --- /dev/null +++ b/colossalai/accelerator/base_accelerator.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python + +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch + +__all__ = ["BaseAccelerator"] + + +class BaseAccelerator(ABC): + support_set_device: bool = True + + def __init__(self, name: str, communication_backend: str, is_synchronous: bool) -> None: + self._name = name + self._communication_backend = communication_backend + self._is_synchronous = is_synchronous + + # ======================= + # immutable attributes + # ======================= + + @property + def name(self) -> str: + """ + Return the name of the accelerator. + """ + return self._name + + @property + def communication_backend(self) -> str: + """ + Return the name of the backend communication library. + """ + return self._communication_backend + + @property + def is_synchronous(self) -> bool: + """ + Return whether the accelerator is a synchronous device. + """ + return self._is_synchronous + + def __repr__(self) -> str: + cls_name = self.__class__.__name__ + return f"{cls_name}(name={self._name}, communication_backend={self._communication_backend}, is_synchronous={self._is_synchronous})" + + # ======================= + # device APIs + # ======================= + @abstractmethod + def get_version(self) -> str: + """ + Return the version of the accelerator which torch is built against. + """ + + @abstractmethod + def get_current_device(self) -> torch.device: + """ + Return the current device. + """ + + @abstractmethod + def current_device(self) -> int: + """ + Return the current device index. + """ + + @abstractmethod + def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None: + """ + Bind the current process to a device. + """ + + @abstractmethod + def get_device_name(self, device: Union[torch.device, int]) -> str: + """ + Return the name of the device. + """ + + @abstractmethod + def synchronize(self, device: Union[torch.device, int] = None): + """ + Synchronize the current process. + """ + + @abstractmethod + def is_available(self): + """ + Check if the accelerator is available. + """ + + @abstractmethod + def device_count(self): + """ + Return the number of devices on the machine. + """ + + def set_to_device(self, models: Any) -> Any: + """ + Send model to device. + + :param models: nn.module or a list of module + """ + if isinstance(models, list) and len(models) > 1: + ret = [] + for model in models: + ret.append(model.to(self.get_current_device())) + return ret + elif isinstance(models, list): + return models[0].to(self.get_current_device()) + else: + return models.to(self.get_current_device()) + + @abstractmethod + def get_device_capability(self, device=None) -> Tuple[int, int]: + """ + Gets the capability of a device. + """ + + @abstractmethod + def get_device_name(self, device=None) -> str: + """ + Gets the name of a device. + """ + + @abstractmethod + def get_device_properties(self, device): + """ + Gets the properties of a device. + """ + + @abstractmethod + def utilization(self, device=None) -> int: + """ + Returns the percent of time over the past sample period during which one or more kernels was executing on the device as given by nvidia-smi or npu-smi, etc. + """ + + # ======================= + # random number generator APIs + # ======================= + @abstractmethod + def get_rng_state(self, device="cuda") -> torch.Tensor: + """ + Returns the random number generator state of the specified device as a ByteTensor. + """ + + @abstractmethod + def get_rng_state_all(self) -> List[torch.Tensor]: + """ + Returns a list of ByteTensor representing the random number states of all devices. + """ + + @abstractmethod + def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None: + """ + Sets the random number generator state of the specified device. + """ + + @abstractmethod + def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None: + """ + Sets the random number generator state of all devices. + """ + + @abstractmethod + def manual_seed(self, seed: int) -> None: + """ + Sets the seed for generating random numbers for the current device. + """ + + @abstractmethod + def manual_seed_all(self, seed: int) -> None: + """ + Sets the seed for generating random numbers on all devices. + """ + + @abstractmethod + def seed(self) -> None: + """ + Sets the seed for generating random numbers to a random number for the current device. + """ + + @abstractmethod + def seed_all(self) -> None: + """ + Sets the seed for generating random numbers to a random number on all devices. + """ + + @abstractmethod + def initial_seed(self) -> int: + """ + Returns the current random seed of the current device. + """ + + # ======================= + # memory management APIs + # ======================= + @abstractmethod + def empty_cache(self) -> None: + """ + Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other device application and visible in nvidia-smi. + """ + + @abstractmethod + def memory_stats(self, device=None) -> Dict[str, Any]: + """ + Returns a dictionary of CUDA memory allocator statistics for a given device. + """ + + @abstractmethod + def memory_summary(self, device=None, abbreviated=False) -> str: + """ + Returns a human-readable printout of the current memory allocator statistics for a given device. + """ + + @abstractmethod + def memory_snapshot(self): + """ + Returns a snapshot of the CUDA memory allocator state across all devices. + """ + + @abstractmethod + def memory_allocated(self, device=None) -> int: + """ + Returns the current device memory occupied by tensors in bytes for a given device. + """ + + @abstractmethod + def max_memory_allocated(self, device=None) -> int: + """ + Returns the maximum device memory occupied by tensors in bytes for a given device. + """ + + @abstractmethod + def reset_max_memory_allocated(self, device=None) -> None: + """ + Resets the starting point in tracking maximum device memory occupied by tensors for a given device. + """ + + @abstractmethod + def reset_max_memory_cached(self, device=None) -> None: + """ + Resets the starting point in tracking maximum device memory managed by the caching allocator for a given device. + """ + + @abstractmethod + def memory_reserved(self, device=None) -> int: + """ + Returns the current device memory managed by the caching allocator in bytes for a given device. + """ + + @abstractmethod + def max_memory_reserved(self, device=None) -> int: + """ + Returns the maximum device memory managed by the caching allocator in bytes for a given device. + """ + + @abstractmethod + def set_per_process_memory_fraction(self, fraction: float, device=None) -> None: + """ + Set memory fraction for a process. + """ + + @abstractmethod + def reset_peak_memory_stats(self, device=None) -> None: + """ + Resets the "peak" stats tracked by the device memory allocator. + """ + + # ======================= + # streams and events APIs + # ======================= + + @abstractmethod + def Stream(self, device=None, priority=0, **kwargs): + """ + A device stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details. + """ + + @abstractmethod + def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): + """ + device events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams. + """ + + @abstractmethod + def current_stream(self, device=None): + """ + Returns the currently selected Stream for a given device. + """ + + @abstractmethod + def default_stream(self, device=None): + """ + Returns the default Stream for a given device. + """ + + @abstractmethod + def set_stream(self, stream_): + """ + Sets the current stream.This is a wrapper API to set the stream. + """ + + @abstractmethod + def stream(self, stream_): + """ + Wrapper around the Context-manager StreamContext that selects a given stream. + """ + + # ======================= + # amp APIs + # ======================= + @abstractmethod + def autocast( + self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True + ) -> Callable: + """ + Return autocast function + """ diff --git a/colossalai/accelerator/cpu_accelerator.py b/colossalai/accelerator/cpu_accelerator.py new file mode 100644 index 000000000..080aa61e8 --- /dev/null +++ b/colossalai/accelerator/cpu_accelerator.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python + +import resource +from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import psutil +import torch + +from .base_accelerator import BaseAccelerator + +__all__ = ["CpuAccelerator"] + + +class CpuAccelerator(BaseAccelerator): + support_set_device: bool = False + """ + Accelerator class for cpu. + """ + + def __init__(self): + super().__init__(name="cpu", communication_backend="gloo", is_synchronous=False) + + # ======================= + # device APIs + # ======================= + def get_version(self) -> str: + """ + Return the version of the accelerator which torch is built against. + """ + return "" + + def get_current_device(self) -> torch.device: + """ + Return the current device. + """ + return torch.device("cpu") + + def current_device(self) -> int: + """ + Return the current device index. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None: + """ + Bind the current process to a device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def get_device_name(self, device: Union[torch.device, int]) -> str: + """ + Return the name of the device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def synchronize(self, device: Union[torch.device, int] = None): + """ + Synchronize the current process. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def is_available(self): + """ + Check if the accelerator is available. + """ + return True + + def device_count(self): + """ + Return the number of devices on the machine. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def get_device_capability(self, device=None) -> Tuple[int, int]: + """ + Gets the cuda capability of a device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def get_device_name(self, device=None) -> str: + """ + Gets the name of a device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def get_device_properties(self, device): + """ + Gets the properties of a device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def utilization(self, device=None) -> int: + """ + Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + # ======================= + # random number generator APIs + # ======================= + def get_rng_state(self, device=None) -> torch.Tensor: + """ + Returns the random number generator state of the specified GPU as a ByteTensor. + """ + return torch.get_rng_state(device) + + def get_rng_state_all(self) -> List[torch.Tensor]: + """ + Returns a list of ByteTensor representing the random number states of all devices. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def set_rng_state(self, new_state: torch.ByteTensor, device: str = None) -> None: + """ + Sets the random number generator state of the specified GPU. + """ + torch.set_rng_state(new_state) + + def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None: + """ + Sets the random number generator state of all devices. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def manual_seed(self, seed: int) -> None: + """ + Sets the seed for generating random numbers for the current GPU. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def manual_seed_all(self, seed: int) -> None: + """ + Set the random seed for the all processes. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def seed(self) -> None: + """ + Sets the seed for generating random numbers to a random number for the current GPU. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def seed_all(self) -> None: + """ + Sets the seed for generating random numbers to a random number on all GPUs. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def initial_seed(self) -> int: + """ + Returns the current random seed of the current GPU. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + # ======================= + # memory management APIs + # ======================= + + def empty_cache(self) -> None: + """ + Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_stats(self, device=None) -> Dict[str, Any]: + """ + Returns a dictionary of CUDA memory allocator statistics for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_summary(self, device=None, abbreviated=False) -> str: + """ + Returns a human-readable printout of the current memory allocator statistics for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_snapshot(self): + """ + Returns a snapshot of the CUDA memory allocator state across all devices. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_allocated(self, device=None) -> int: + """ + Returns the current GPU memory occupied by tensors in bytes for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def max_memory_allocated(self, device=None) -> int: + """ + Returns the maximum GPU memory occupied by tensors in bytes for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def reset_max_memory_allocated(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def reset_max_memory_cached(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_reserved(self, device=None) -> int: + """ + Returns the current GPU memory managed by the caching allocator in bytes for a given device. + """ + return psutil.Process().memory_info().rss + + def max_memory_reserved(self, device=None) -> int: + """ + Returns the maximum GPU memory managed by the caching allocator in bytes for a given device. + """ + return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + + def set_per_process_memory_fraction(self, fraction: float, device=None) -> None: + """ + Set memory fraction for a process. + """ + max_memory = int(psutil.virtual_memory().total * fraction) + _, hard = resource.getrlimit(resource.RLIMIT_AS) + resource.setrlimit(resource.RLIMIT_AS, (max_memory, hard)) + + def reset_peak_memory_stats(self, device=None) -> None: + """ + Resets the "peak" stats tracked by the CUDA memory allocator. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + # ======================= + # streams and events APIs + # ======================= + + def Stream(self, device=None, priority=0, **kwargs): + """ + A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): + """ + CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def current_stream(self, device=None): + """ + Returns the currently selected Stream for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def default_stream(self, device=None): + """ + Returns the default Stream for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def set_stream(self, stream_): + """ + Sets the current stream.This is a wrapper API to set the stream. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def stream(self, stream_): + """ + Wrapper around the Context-manager StreamContext that selects a given stream. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + # ======================= + # amp APIs + # ======================= + def autocast( + self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True + ) -> Callable: + """ + Return autocast function + """ + return nullcontext diff --git a/colossalai/accelerator/cuda_accelerator.py b/colossalai/accelerator/cuda_accelerator.py new file mode 100644 index 000000000..f1ab487d4 --- /dev/null +++ b/colossalai/accelerator/cuda_accelerator.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist + +from .base_accelerator import BaseAccelerator + +__all__ = ["CudaAccelerator"] + + +class CudaAccelerator(BaseAccelerator): + """ + Accelerator class for Nvidia CUDA devices. + """ + + def __init__(self): + super().__init__(name="cuda", communication_backend="nccl", is_synchronous=False) + + # ======================= + # device APIs + # ======================= + def get_version(self) -> str: + """ + Return the version of the accelerator which torch is built against. + """ + return torch.version.cuda + + def get_current_device(self) -> torch.device: + """ + Return the current device. + """ + return torch.device(f"cuda:{torch.cuda.current_device()}") + + def current_device(self) -> int: + """ + Return the current device index. + """ + return torch.cuda.current_device() + + def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None: + """ + Bind the current process to a device. + """ + if device is None: + if not dist.is_initialized(): + raise RuntimeError("Cannot get current device when distributed is not initialized.") + device = dist.get_rank() % self.device_count() + torch.cuda.set_device(device) + + def get_device_name(self, device: Union[torch.device, int]) -> str: + """ + Return the name of the device. + """ + return torch.cuda.get_device_name(device) + + def synchronize(self, device: Union[torch.device, int] = None): + """ + Synchronize the current process. + """ + torch.cuda.synchronize(device) + + def is_available(self): + """ + Check if the accelerator is available. + """ + return torch.cuda.is_available() + + def device_count(self): + """ + Return the number of devices on the machine. + """ + return torch.cuda.device_count() + + def get_device_capability(self, device=None) -> Tuple[int, int]: + """ + Gets the cuda capability of a device. + """ + return torch.cuda.get_device_capability(device) + + def get_device_name(self, device=None) -> str: + """ + Gets the name of a device. + """ + return torch.cuda.get_device_name(device) + + def get_device_properties(self, device): + """ + Gets the properties of a device. + """ + return torch.cuda.get_device_properties(device) + + def utilization(self, device=None) -> int: + """ + Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi + """ + return torch.cuda.utilization(device) + + # ======================= + # random number generator APIs + # ======================= + def get_rng_state(self, device="cuda") -> torch.Tensor: + """ + Returns the random number generator state of the specified GPU as a ByteTensor. + """ + return torch.cuda.get_rng_state(device) + + def get_rng_state_all(self) -> List[torch.Tensor]: + """ + Returns a list of ByteTensor representing the random number states of all devices. + """ + return torch.cuda.get_rng_state_all() + + def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None: + """ + Sets the random number generator state of the specified GPU. + """ + torch.cuda.set_rng_state(new_state, device) + + def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None: + """ + Sets the random number generator state of all devices. + """ + torch.cuda.set_rng_state_all(new_states) + + def manual_seed(self, seed: int) -> None: + """ + Sets the seed for generating random numbers for the current GPU. + """ + torch.cuda.manual_seed(seed) + + def manual_seed_all(self, seed: int) -> None: + """ + Set the random seed for the all processes. + """ + torch.cuda.manual_seed_all(seed) + + def seed(self) -> None: + """ + Sets the seed for generating random numbers to a random number for the current GPU. + """ + torch.cuda.seed() + + def seed_all(self) -> None: + """ + Sets the seed for generating random numbers to a random number on all GPUs. + """ + torch.cuda.seed_all() + + def initial_seed(self) -> int: + """ + Returns the current random seed of the current GPU. + """ + return torch.cuda.initial_seed() + + # ======================= + # memory management APIs + # ======================= + + def empty_cache(self) -> None: + """ + Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi. + """ + torch.cuda.empty_cache() + + def memory_stats(self, device=None) -> Dict[str, Any]: + """ + Returns a dictionary of CUDA memory allocator statistics for a given device. + """ + return torch.cuda.memory_stats(device=device) + + def memory_summary(self, device=None, abbreviated=False) -> str: + """ + Returns a human-readable printout of the current memory allocator statistics for a given device. + """ + return torch.cuda.memory_summary(device=device, abbreviated=abbreviated) + + def memory_snapshot(self): + """ + Returns a snapshot of the CUDA memory allocator state across all devices. + """ + return torch.cuda.memory_snapshot() + + def memory_allocated(self, device=None) -> int: + """ + Returns the current GPU memory occupied by tensors in bytes for a given device. + """ + return torch.cuda.memory_allocated(device=device) + + def max_memory_allocated(self, device=None) -> int: + """ + Returns the maximum GPU memory occupied by tensors in bytes for a given device. + """ + return torch.cuda.max_memory_allocated(device=device) + + def reset_max_memory_allocated(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device. + """ + torch.cuda.reset_max_memory_allocated(device=device) + + def reset_max_memory_cached(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device. + """ + torch.cuda.reset_max_memory_cached(device=device) + + def memory_reserved(self, device=None) -> int: + """ + Returns the current GPU memory managed by the caching allocator in bytes for a given device. + """ + return torch.cuda.memory_reserved(device=device) + + def max_memory_reserved(self, device=None) -> int: + """ + Returns the maximum GPU memory managed by the caching allocator in bytes for a given device. + """ + return torch.cuda.max_memory_reserved(device=device) + + def set_per_process_memory_fraction(self, fraction: float, device=None) -> None: + """ + Set memory fraction for a process. + """ + torch.cuda.set_per_process_memory_fraction(fraction, device=device) + + def reset_peak_memory_stats(self, device=None) -> None: + """ + Resets the "peak" stats tracked by the CUDA memory allocator. + """ + torch.cuda.reset_peak_memory_stats(device=device) + + # ======================= + # streams and events APIs + # ======================= + + def Stream(self, device=None, priority=0, **kwargs): + """ + A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details. + """ + return torch.cuda.Stream(device, priority, **kwargs) + + def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): + """ + CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams. + """ + return torch.cuda.Event(enable_timing, blocking, interprocess) + + def current_stream(self, device=None): + """ + Returns the currently selected Stream for a given device. + """ + return torch.cuda.current_stream(device) + + def default_stream(self, device=None): + """ + Returns the default Stream for a given device. + """ + return torch.cuda.default_stream(device) + + def set_stream(self, stream_): + """ + Sets the current stream.This is a wrapper API to set the stream. + """ + torch.cuda.set_stream(stream_) + + def stream(self, stream_): + """ + Wrapper around the Context-manager StreamContext that selects a given stream. + """ + return torch.cuda.stream(stream_) + + # ======================= + # amp APIs + # ======================= + def autocast( + self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True + ) -> Callable: + """ + Return autocast function + """ + return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) diff --git a/colossalai/accelerator/npu_accelerator.py b/colossalai/accelerator/npu_accelerator.py new file mode 100644 index 000000000..b28492968 --- /dev/null +++ b/colossalai/accelerator/npu_accelerator.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist + +from .base_accelerator import BaseAccelerator + +try: + import torch_npu # noqa +except ImportError: + pass + + +__all__ = ["NpuAccelerator"] + + +class NpuAccelerator(BaseAccelerator): + """ + Accelerator class for Huawei NPU devices. + """ + + def __init__(self): + super().__init__(name="npu", communication_backend="hccl", is_synchronous=False) + + # ======================= + # device APIs + # ======================= + def get_version(self) -> str: + """ + Return the version of the accelerator which torch is built against. + """ + return torch.version.cann + + def get_current_device(self) -> torch.device: + """ + Return the current device. + """ + return torch.device(f"npu:{torch.npu.current_device()}") + + def current_device(self) -> int: + """ + Return the current device index. + """ + return torch.npu.current_device() + + def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None: + """ + Bind the current process to a device. + """ + if device is None: + if not dist.is_initialized(): + raise RuntimeError("Cannot get current device when distributed is not initialized.") + device = dist.get_rank() % self.device_count() + torch.npu.set_device(device) + + def get_device_name(self, device: Union[torch.device, int]) -> str: + """ + Return the name of the device. + """ + return torch.npu.get_device_name(device) + + def synchronize(self, device: Union[torch.device, int] = None): + """ + Synchronize the current process. + """ + torch.npu.synchronize(device) + + def is_available(self): + """ + Check if the accelerator is available. + """ + return torch.npu.is_available() + + def device_count(self): + """ + Return the number of devices on the machine. + """ + return torch.npu.device_count() + + def get_device_capability(self, device=None) -> Tuple[int, int]: + """ + Gets the npu capability of a device. + """ + return torch.npu.get_device_capability(device) + + def get_device_name(self, device=None) -> str: + """ + Gets the name of a device. + """ + return torch.npu.get_device_name(device) + + def get_device_properties(self, device): + """ + Gets the properties of a device. + """ + return torch.npu.get_device_properties(device) + + def utilization(self, device=None) -> int: + """ + Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi + """ + return torch.npu.utilization(device) + + # ======================= + # random number generator APIs + # ======================= + def get_rng_state(self, device="npu") -> torch.Tensor: + """ + Returns the random number generator state of the specified GPU as a ByteTensor. + """ + return torch.npu.get_rng_state(device) + + def get_rng_state_all(self) -> List[torch.Tensor]: + """ + Returns a list of ByteTensor representing the random number states of all devices. + """ + return torch.npu.get_rng_state_all() + + def set_rng_state(self, new_state: torch.ByteTensor, device: str = "npu") -> None: + """ + Sets the random number generator state of the specified GPU. + """ + torch.npu.set_rng_state(new_state, device) + + def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None: + """ + Sets the random number generator state of all devices. + """ + torch.npu.set_rng_state_all(new_states) + + def manual_seed(self, seed: int) -> None: + """ + Sets the seed for generating random numbers for the current GPU. + """ + torch.npu.manual_seed(seed) + + def manual_seed_all(self, seed: int) -> None: + """ + Set the random seed for the all processes. + """ + torch.npu.manual_seed_all(seed) + + def seed(self) -> None: + """ + Sets the seed for generating random numbers to a random number for the current GPU. + """ + torch.npu.seed() + + def seed_all(self) -> None: + """ + Sets the seed for generating random numbers to a random number on all GPUs. + """ + torch.npu.seed_all() + + def initial_seed(self) -> int: + """ + Returns the current random seed of the current GPU. + """ + return torch.npu.initial_seed() + + # ======================= + # memory management APIs + # ======================= + + def empty_cache(self) -> None: + """ + Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi. + """ + torch.npu.empty_cache() + + def memory_stats(self, device=None) -> Dict[str, Any]: + """ + Returns a dictionary of npu memory allocator statistics for a given device. + """ + return torch.npu.memory_stats(device=device) + + def memory_summary(self, device=None, abbreviated=False) -> str: + """ + Returns a human-readable printout of the current memory allocator statistics for a given device. + """ + return torch.npu.memory_summary(device=device, abbreviated=abbreviated) + + def memory_snapshot(self): + """ + Returns a snapshot of the npu memory allocator state across all devices. + """ + return torch.npu.memory_snapshot() + + def memory_allocated(self, device=None) -> int: + """ + Returns the current GPU memory occupied by tensors in bytes for a given device. + """ + return torch.npu.memory_allocated(device=device) + + def max_memory_allocated(self, device=None) -> int: + """ + Returns the maximum GPU memory occupied by tensors in bytes for a given device. + """ + return torch.npu.max_memory_allocated(device=device) + + def reset_max_memory_allocated(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device. + """ + torch.npu.reset_max_memory_allocated(device=device) + + def reset_max_memory_cached(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device. + """ + torch.npu.reset_max_memory_cached(device=device) + + def memory_reserved(self, device=None) -> int: + """ + Returns the current GPU memory managed by the caching allocator in bytes for a given device. + """ + return torch.npu.memory_reserved(device=device) + + def max_memory_reserved(self, device=None) -> int: + """ + Returns the maximum GPU memory managed by the caching allocator in bytes for a given device. + """ + return torch.npu.max_memory_reserved(device=device) + + def set_per_process_memory_fraction(self, fraction: float, device=None) -> None: + """ + Set memory fraction for a process. + """ + torch.npu.set_per_process_memory_fraction(fraction, device=device) + + def reset_peak_memory_stats(self, device=None) -> None: + """ + Resets the "peak" stats tracked by the npu memory allocator. + """ + torch.npu.reset_peak_memory_stats(device=device) + + # ======================= + # streams and events APIs + # ======================= + + def Stream(self, device=None, priority=0, **kwargs): + """ + A npu stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See npu-semantics for details. + """ + return torch.npu.Stream(device, priority, **kwargs) + + def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): + """ + npu events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize npu streams. + """ + return torch.npu.Event(enable_timing, blocking, interprocess) + + def current_stream(self, device=None): + """ + Returns the currently selected Stream for a given device. + """ + return torch.npu.current_stream(device) + + def default_stream(self, device=None): + """ + Returns the default Stream for a given device. + """ + return torch.npu.default_stream(device) + + def set_stream(self, stream_): + """ + Sets the current stream.This is a wrapper API to set the stream. + """ + torch.npu.set_stream(stream_) + + def stream(self, stream_): + """ + Wrapper around the Context-manager StreamContext that selects a given stream. + """ + return torch.npu.stream(stream_) + + # ======================= + # amp APIs + # ======================= + def autocast( + self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True + ) -> Callable: + """ + Return autocast function + """ + return torch.npu.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py index 439d13dcf..fc4c884d4 100644 --- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py @@ -7,8 +7,8 @@ from typing import Dict import torch from torch import Tensor +from colossalai.accelerator import get_accelerator from colossalai.logging import get_dist_logger -from colossalai.utils.device import get_current_device __all__ = ["BaseGradScaler"] @@ -23,7 +23,7 @@ class BaseGradScaler(ABC): def __init__(self, initial_scale: float, verbose: bool): assert initial_scale > 0 - self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float) + self._scale = torch.tensor([initial_scale], device=get_accelerator().get_current_device(), dtype=torch.float) self._verbose = verbose if self._verbose: diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py index 86ba919ee..5cd8035d7 100644 --- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py @@ -5,7 +5,7 @@ from typing import Optional import torch -from colossalai.utils.device import get_current_device +from colossalai.accelerator import get_accelerator from .base_grad_scaler import BaseGradScaler @@ -37,14 +37,20 @@ class DynamicGradScaler(BaseGradScaler): hysteresis: int = 2, verbose: bool = False, ): + a = get_accelerator() + a.device_count() super().__init__(initial_scale, verbose) if min_scale: - self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float) + self._min_scale = torch.tensor( + [min_scale], device=get_accelerator().get_current_device(), dtype=torch.float + ) else: self._min_scale = None if max_scale: - self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float) + self._max_scale = torch.tensor( + [max_scale], device=get_accelerator().get_current_device(), dtype=torch.float + ) else: self._max_scale = None @@ -117,7 +123,7 @@ class DynamicGradScaler(BaseGradScaler): return state_dict def load_state_dict(self, state_dict): - self._scale = state_dict["scale"].to(get_current_device()) + self._scale = state_dict["scale"].to(get_accelerator().get_current_device()) self._growth_factor = state_dict["growth_factor"] self._backoff_factor = state_dict["backoff_factor"] self._hysteresis = state_dict["hysteresis"] diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py index 9ce272356..2e7c8a281 100644 --- a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py @@ -5,8 +5,8 @@ import torch import torch.distributed as dist from torch import Tensor +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler -from colossalai.utils import get_current_device from .base import MixedPrecisionMixin @@ -40,7 +40,7 @@ class FP16MixedPrecisionMixin(MixedPrecisionMixin): max_scale=max_scale, ) self.optim_state = OptimState.UNSCALED - self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device()) + self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device()) @property def loss_scale(self) -> float: diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py index 601bf2926..fe8439269 100644 --- a/colossalai/auto_parallel/offload/amp_optimizer.py +++ b/colossalai/auto_parallel/offload/amp_optimizer.py @@ -4,10 +4,10 @@ from typing import Dict, Tuple import torch from torch.optim import Optimizer +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device from .base_offload_module import BaseOffloadModule from .region import Region @@ -79,7 +79,9 @@ class AMPOptimizer(OptimizerWrapper): hysteresis=hysteresis, max_scale=max_scale, ) - self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + self._found_overflow: torch.Tensor = torch.zeros( + 1, dtype=torch.int64, device=get_accelerator().get_current_device() + ) self._logger = get_dist_logger() def _set_grad_ptr(self): diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py index a6628e29c..3ad210de9 100644 --- a/colossalai/auto_parallel/offload/solver.py +++ b/colossalai/auto_parallel/offload/solver.py @@ -11,7 +11,7 @@ except: import torch from torch.fx.node import Node -from colossalai.utils.device import get_current_device +from colossalai.accelerator import get_accelerator from .region import Region from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator @@ -57,7 +57,10 @@ class Solver(ABC): if memory_budget > 0: self.memory_budget = memory_budget * self.error_factor else: - self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor + self.memory_budget = ( + torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory + * self.error_factor + ) self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth() self.comp_power: float = self._extract_computing_power() diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py index 443c4094c..c757a878d 100644 --- a/colossalai/booster/mixed_precision/fp16_torch.py +++ b/colossalai/booster/mixed_precision/fp16_torch.py @@ -5,8 +5,8 @@ import torch.nn as nn from torch import Tensor from torch.optim import Optimizer +from colossalai.accelerator import get_accelerator from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.utils.device import autocast from .mixed_precision_base import MixedPrecision @@ -89,7 +89,7 @@ class TorchAMPModule(ModelWrapper): super().__init__(module) def forward(self, *args, **kwargs): - with autocast(): + with get_accelerator().autocast(): return self.module(*args, **kwargs) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index a891db422..d14109dd4 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -15,6 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from colossalai.accelerator import get_accelerator from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO from colossalai.checkpoint_io.utils import ( get_model_base_filenames, @@ -27,8 +28,6 @@ from colossalai.checkpoint_io.utils import ( from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.utils import get_current_device -from colossalai.utils.device import IS_NPU_AVAILABLE from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats @@ -366,11 +365,11 @@ class GeminiPlugin(DPPluginBase): ) -> None: super().__init__() assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" - if IS_NPU_AVAILABLE: + if get_accelerator().name == "npu": assert placement_policy == "static", "NPU only supports static placement policy" self.gemini_config = dict( chunk_config_dict=chunk_config_dict, - chunk_init_device=(chunk_init_device or get_current_device()), + chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()), placement_policy=placement_policy, enable_gradient_accumulation=enable_gradient_accumulation, shard_param_frac=shard_param_frac, @@ -455,7 +454,7 @@ class GeminiPlugin(DPPluginBase): def supported_devices(self) -> List[str]: return ["cuda", "npu"] - + def prepare_dataloader( self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs ): @@ -486,7 +485,10 @@ class GeminiPlugin(DPPluginBase): zero_rank = self.pg_mesh.coordinate(ZERO_AXIS) extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS) sampler = DistributedSampler( - dataset, num_replicas=zero_world_size * extra_dp_world_size, rank=zero_rank * extra_dp_world_size + extra_dp_rank, shuffle=shuffle + dataset, + num_replicas=zero_world_size * extra_dp_world_size, + rank=zero_rank * extra_dp_world_size + extra_dp_rank, + shuffle=shuffle, ) # Deterministic dataloader diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 8ee1e97c6..5837156a9 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -18,6 +18,7 @@ from torch.utils._pytree import tree_map from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh @@ -28,7 +29,6 @@ from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.layer.utils import SeqParallelUtils from colossalai.shardformer.policies.base_policy import Policy from colossalai.tensor.d_tensor.api import is_distributed_tensor -from colossalai.utils.device import get_current_device from colossalai.zero.low_level import LowLevelZeroOptimizer from .pp_plugin_base import PipelinePluginBase @@ -82,7 +82,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): self.mixed_precision = torch.bfloat16 if self.mixed_precision is not None: module = module.to(self.mixed_precision) - module = module.to(get_current_device()) + module = module.to(get_accelerator().get_current_device()) # setting input type cast when using mixed precision self.convert_fn = None @@ -165,7 +165,6 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): Returns: None """ - if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism: if grads is not None: # Synchronize provided gradient tensors across the tensor parallelism group. @@ -346,7 +345,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper): if norm_type == inf: total_norm = max(grad.data.abs().max() for grad in gradients) - total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) if self.tp_size > 1: dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) if self.pp_size > 1: @@ -386,7 +387,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper): total_norm_exponentiated += grad_norm_exponentiated total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32 + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 ) if self.tp_size > 1: # compute norm in tp process group @@ -487,7 +488,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): Returns: None """ - # Call the superclass backward method to compute gradients. super().backward(loss, *args, **kwargs) @@ -513,7 +513,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): Returns: None """ - # Call the superclass backward method to compute gradients. super().backward_by_grad(tensor, grad) @@ -545,7 +544,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): # so we need to calculate the norm of 'tp' and 'pp' gradients. total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type) - total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) if self.tp_size > 1: dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) @@ -589,7 +590,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): total_norm_exponentiated += grad_norm_exponentiated total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32 + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 ) if self.tp_size > 1: # compute norm in tp process group @@ -674,7 +675,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): Returns: None """ - # Call the superclass `_sync_grad` method to synchronize gradients. super()._sync_grad() @@ -802,7 +802,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): # so we only need to calculate the norm 'tp' of 'pp' gradients. total_norm = super()._compute_grad_norm(gradients, norm_type) - total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) if tp_size > 1: dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) @@ -842,7 +844,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): total_norm_exponentiated += grad_norm_exponentiated total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32 + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 ) if dp_size > 1: # compute norm in dp process group @@ -1081,7 +1083,7 @@ class HybridParallelPlugin(PipelinePluginBase): return True def support_no_sync(self) -> bool: - return False + return True def control_checkpoint_io(self) -> bool: return True @@ -1175,9 +1177,14 @@ class HybridParallelPlugin(PipelinePluginBase): model, data_iter, criterion, optimizer, return_loss, return_outputs ) + # run with gradients accumulation + if model.require_grad_sync == False or ( + isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False + ): + return outputs + # Synchronize the grads of shared parameters of the model. model.sync_shared_params() - # Synchronize sequence parallelism gradients of the model. model.sync_sp_grads() @@ -1241,5 +1248,8 @@ class HybridParallelPlugin(PipelinePluginBase): def get_checkpoint_io(self) -> CheckpointIO: return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) - def no_sync(self, model: Module) -> Iterator[None]: - raise NotImplementedError + def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]: + assert ( + self.zero_stage != 2 + ), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed." + return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 89102820c..d21496f0b 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -12,6 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils._pytree import tree_map from torch.utils.data import DataLoader +from colossalai.accelerator import get_accelerator from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO from colossalai.checkpoint_io.utils import ( get_optimizer_base_filenames, @@ -24,7 +25,6 @@ from colossalai.checkpoint_io.utils import ( sharded_optimizer_loading_epilogue, ) from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper -from colossalai.utils import get_current_device from colossalai.zero import LowLevelZeroOptimizer from .dp_plugin_base import DPPluginBase @@ -52,7 +52,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin): self.dtype = torch.bfloat16 if self.dtype is not None: module = module.to(self.dtype) - module = module.to(get_current_device()) + module = module.to(get_accelerator().get_current_device()) self.module = module self.convert_fn = None if self.dtype is not None: diff --git a/colossalai/initialize.py b/colossalai/initialize.py index 25076b742..aaeaad382 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -6,12 +6,12 @@ import warnings from pathlib import Path from typing import Dict, Union -import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.context import Config from colossalai.logging import get_dist_logger -from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed +from colossalai.utils import set_seed def launch( @@ -47,17 +47,18 @@ def launch( if rank == 0: warnings.warn("`config` is deprecated and will be removed soon.") - if IS_NPU_AVAILABLE and backend == "nccl": - backend = "hccl" + cur_accelerator = get_accelerator() + + backend = cur_accelerator.communication_backend # init default process group init_method = f"tcp://[{host}]:{port}" dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) # set cuda device - if torch.cuda.is_available() or IS_NPU_AVAILABLE: - # if local rank is not given, calculate automatically - set_device(local_rank) + # if local rank is not given, calculate automatically + if cur_accelerator.support_set_device: + cur_accelerator.set_device(local_rank) set_seed(seed) diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py index 8933fc0a3..e69de29bb 100644 --- a/colossalai/kernel/__init__.py +++ b/colossalai/kernel/__init__.py @@ -1,7 +0,0 @@ -from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention - -__all__ = [ - "LayerNorm", - "FusedScaleMaskSoftmax", - "MultiHeadAttention", -] diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu deleted file mode 100644 index 2b1b366b1..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu +++ /dev/null @@ -1,63 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include "column_remap.cuh" -#include "util.cuh" - -const int SHUF_BLOCKSIZE_X = 256; -const int SHUF_BLOCKSIZE_Y = 16; - -__global__ void column_remap_kernel -( - const half* __restrict__ x, - half* __restrict__ x_new, - const int x_width, - const int x_height, - const uint32_t* x_map -) -{ - int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; - int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y; - if (x_column >= x_width) return; - //if (x_row >= x_height) return; - - int x_stride = x_width; - int x_idx = x_row * x_stride + x_column; - - int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height); - int x_idx_end = x_row_end * x_stride + x_column; - - int s_column = x_map[x_column]; - int s_idx = x_row * x_stride + s_column; - - while (x_idx < x_idx_end) - { - x_new[x_idx] = x[s_idx]; - x_idx += x_stride; - s_idx += x_stride; - } -} - -// Remap columns in x to correspond to sequential group index before matmul -// -// perform x -> seq_x such that seq_x @ seq_w == x @ w - -void column_remap_cuda -( - const half* x, - half* x_new, - const int x_height, - const int x_width, - const uint32_t* x_map -) -{ - dim3 threads(SHUF_BLOCKSIZE_X, 1, 1); - - dim3 blocks - ( - (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X, - (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y, - 1 - ); - - column_remap_kernel<<>>(x, x_new, x_width, x_height, x_map); -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh deleted file mode 100644 index 0364e38c4..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh +++ /dev/null @@ -1,19 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _column_remap_cuh -#define _column_remap_cuh - -#include -#include -#include - -void column_remap_cuda -( - const half* x, - half* x_new, - const int x_height, - const int x_width, - const uint32_t* x_map -); - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh deleted file mode 100644 index c5258813e..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh +++ /dev/null @@ -1,58 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _cuda_compat_cuh -#define _cuda_compat_cuh - -// atomicAdd for half types, to support CC < 7.x - -__device__ __forceinline__ void atomicAdd_half(half* address, half val) -{ - unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); - unsigned int old = *address_as_ui; - unsigned int assumed; - - do - { - assumed = old; - __half_raw hsum; - hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); - half tmpres = __hadd(hsum, val); - hsum = __half_raw(tmpres); - old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; - old = atomicCAS(address_as_ui, assumed, old); - } - while (assumed != old); -} - -// atomicAdd for half2 types - -__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) -{ - unsigned int* address_as_ui = (unsigned int*)address; - unsigned int old = *address_as_ui; - unsigned int assumed; - do - { - assumed = old; - half2 old_val = *((half2*)&old); - half2 new_val = __hadd2(old_val, val); - old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); - } - while (assumed != old); -} - -// - -#if defined(__CUDA_ARCH__) || defined(USE_ROCM) -#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) - -__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } - -#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) -__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } -#endif - -#endif -#endif - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu deleted file mode 100644 index 4416027c8..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu +++ /dev/null @@ -1,75 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#define _cuda_buffers_cu -#include "cuda_buffers.cuh" - -CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL}; -// __constant__ half2 q4_table[16][256]; -// half2 q4_table_host[16][256]; -// bool q4_table_init = false; - -CudaBuffers::CudaBuffers -( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq -) : - device(_device), - temp_state_size(_temp_state_size), - temp_state(_temp_state), - temp_dq(_temp_dq) -{ - cudaSetDevice(_device); - - cudaStreamCreate(&alt_stream_1); - cudaStreamCreate(&alt_stream_2); - cudaStreamCreate(&alt_stream_3); - cudaEventCreate(&alt_stream_1_done); - cudaEventCreate(&alt_stream_2_done); - cudaEventCreate(&alt_stream_3_done); -} - -CudaBuffers::~CudaBuffers() -{ - cudaStreamDestroy(alt_stream_1); - cudaStreamDestroy(alt_stream_2); - cudaStreamDestroy(alt_stream_3); - cudaEventDestroy(alt_stream_1_done); - cudaEventDestroy(alt_stream_2_done); - cudaEventDestroy(alt_stream_3_done); -} - -CudaBuffers* get_buffers(const int device_index) -{ - return g_buffers[device_index]; -} - -void prepare_buffers_cuda -( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq -) -{ - CudaBuffers* buffers = new CudaBuffers - ( - _device, - _temp_state_size, - _temp_state, - _temp_dq - ); - - g_buffers[_device] = buffers; -} - -void cleanup_buffers_cuda() -{ - for (int i = 0; i < CUDA_MAX_DEVICES; i++) - { - if (!g_buffers[i]) continue; - delete g_buffers[i]; - g_buffers[i] = NULL; - } -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh deleted file mode 100644 index 0bf2057c6..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh +++ /dev/null @@ -1,55 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _cuda_buffers_cuh -#define _cuda_buffers_cuh - -#include -#include -#include -#include - -const int CUDA_MAX_DEVICES = 16; - -// #ifndef _cuda_buffers_cu -// extern __constant__ half2 q4_table[16][256]; -// #endif - -class CudaBuffers -{ -public: - int device; - - half* temp_state; // [max_hidden_rows * intermediate_size] - int temp_state_size; - half* temp_dq; // size of largest quant tensor * 8 - - cudaStream_t alt_stream_1; - cudaStream_t alt_stream_2; - cudaStream_t alt_stream_3; - cudaEvent_t alt_stream_1_done; - cudaEvent_t alt_stream_2_done; - cudaEvent_t alt_stream_3_done; - - CudaBuffers - ( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq - ); - ~CudaBuffers(); -}; - -CudaBuffers* get_buffers(const int device_index); - -void prepare_buffers_cuda -( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq -); - -void cleanup_buffers_cuda(); - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh deleted file mode 100644 index 5cd2e8553..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh +++ /dev/null @@ -1,49 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _hip_compat_cuh -#define _hip_compat_cuh - -// Workaround for a bug in hipamd, backported from upstream. -__device__ __forceinline__ __half __compat_hrcp(__half x) { - return __half_raw{ - static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))}; -} - -__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) { - return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)), - static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))}; -} - -#define hrcp __compat_hrcp -#define h2rcp __compat_h2rcp - -// Workaround for hipify_python using rocblas instead of hipblas. -__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, - hipblasOperation_t transA, - hipblasOperation_t transB, - int m, - int n, - int k, - const half* alpha, - const half* AP, - int lda, - const half* BP, - int ldb, - const half* beta, - half* CP, - int ldc) { - return hipblasHgemm(handle, transA, transB, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(AP), lda, - reinterpret_cast(BP), ldb, - reinterpret_cast(beta), - reinterpret_cast(CP), ldc); -} - -#define rocblas_handle hipblasHandle_t -#define rocblas_operation_none HIPBLAS_OP_N -#define rocblas_get_stream hipblasGetStream -#define rocblas_set_stream hipblasSetStream -#define rocblas_hgemm __compat_hipblasHgemm - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp deleted file mode 100644 index bcc0e4390..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp +++ /dev/null @@ -1,254 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include -#include -#include -#include -#include -#include -#include -#include "util.cuh" -#include "tuning.h" -#include "cuda_buffers.cuh" -#include "q4_matrix.cuh" -#include "q4_matmul.cuh" -#include "column_remap.cuh" - -// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a -// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of -// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. - -void check_cuda(cudaError_t ret) -{ - switch (ret) - { - case cudaSuccess: - break; - - case cudaUnspecified: - printf(" **** Unspecified error\n"); - TORCH_CHECK(false, "CUDA error"); - break; - - default: - printf(" **** CUDA error\n"); \ - printf(" **** %s\n", cudaGetErrorString(ret)); \ - TORCH_CHECK(false, "CUDA error"); \ - break; - } -} - -// Some decluttering macros - -#define STRINGIFY_(__x) #__x -#define STRINGIFY(__x) STRINGIFY_(__x) -#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) -#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") - -#define TORCH_CHECK_DEVICE_INDEX(__index) \ -do { \ - TORCH_CHECK(__index >= 0, "no device index"); \ - TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ -} while(0) - -#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ -do { \ - TORCH_CHECK_DTYPE(__w, kInt); \ - TORCH_CHECK_DTYPE(__w_scales, kHalf); \ - TORCH_CHECK_DTYPE(__w_zeros, kInt); \ - TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ - TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ - TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ - TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ -} while(0) - -int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) -{ - int groupsize = w.size(0) * 8 / w_zeros.size(0); - TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") - return groupsize; -} - - -// Tuning parameters - -ExLlamaTuning tuningParams; - -void set_tuning_params -( - int matmul_recons_thd, - bool matmul_fused_remap, - bool matmul_no_half2 -) -{ - tuningParams.matmul_recons_thd = matmul_recons_thd; - tuningParams.matmul_fused_remap = matmul_fused_remap; - tuningParams.matmul_no_half2 = matmul_no_half2; -} - - -// Release all unmanaged objects allocated by the extension - -void cleanup() -{ - cleanup_buffers_cuda(); - g_q4_free_matrices(); -} - - -// Prepare buffers for forward pass - -void prepare_buffers -( - torch::Device device, - torch::Tensor temp_state, - torch::Tensor temp_dq -) -{ - int device_index = device.index(); - TORCH_CHECK_DEVICE_INDEX(device_index); - const at::cuda::OptionalCUDAGuard device_guard(device); - - prepare_buffers_cuda - ( - device_index, - // buffer size used for sanity checks - temp_state.numel(), - (half*) temp_state.data_ptr(), - (half*) temp_dq.data_ptr() - ); -} - - -// Create Q4Matrix, return handle - -uintptr_t make_q4 -( - torch::Tensor qweight, - torch::Tensor qzeros, - torch::Tensor scales, - torch::Tensor g_idx, - int device -) -{ - TORCH_CHECK_DTYPE(qweight, kInt); - TORCH_CHECK_DTYPE(qzeros, kInt); - TORCH_CHECK_DTYPE(scales, kHalf); - TORCH_CHECK_DTYPE_OPT(g_idx, kInt); - TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); - TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); - TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); - - int width = qweight.size(1); - int height = qweight.size(0) * 8; - int groups = qzeros.size(0); - - Q4Matrix* m = new Q4Matrix - ( - height, - width, - groups, - - (uint32_t*) qweight.data_ptr(), - (uint32_t*) qzeros.data_ptr(), - (half*) scales.data_ptr(), - g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(), - - device - ); - - g_q4_keep_matrix(m); - return reinterpret_cast (m); -} - - -// Matmul half @ quant -> half - -void q4_matmul -( - torch::Tensor x, - uintptr_t w, - torch::Tensor out -) -{ - Q4Matrix* wm = reinterpret_cast (w); - - TORCH_CHECK_DTYPE(x, kHalf); - TORCH_CHECK_DTYPE(out, kHalf); - TORCH_CHECK_SHAPES(x, 0, out, 0, 1); - TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - - int x_height = x.size(0); - - if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) - { - q4_matmul_cuda - ( - &tuningParams, - (half*) x.data_ptr(), - x_height, - wm, - (half*) out.data_ptr() - ); - } - else - { - q4_matmul_recons_cuda - ( - &tuningParams, - (half*) x.data_ptr(), - x_height, - wm, - (half*) out.data_ptr(), - at::cuda::getCurrentCUDABlasHandle() - ); - } -} - - -// Remap columns in half tensor - -void column_remap -( - torch::Tensor x, - torch::Tensor x_new, - torch::Tensor x_map -) -{ - TORCH_CHECK_DTYPE(x, kHalf); - TORCH_CHECK_DTYPE(x_new, kHalf); - TORCH_CHECK_DTYPE(x_map, kInt); - TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); - - int height = x.size(0); - int width = x.size(1); - - TORCH_CHECK_BUFFER_SIZE(x_new, height * width); - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - - column_remap_cuda - ( - (half*) x.data_ptr(), - (half*) x_new.data_ptr(), - height, - width, - (uint32_t*) x_map.data_ptr() - ); -} - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); - m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); - m.def("cleanup", &cleanup, "cleanup"); - m.def("make_q4", &make_q4, "make_q4"); - m.def("q4_matmul", &q4_matmul, "q4_matmul"); -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh deleted file mode 100644 index 2fd5ab0b3..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh +++ /dev/null @@ -1,294 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _matrix_cuh -#define _matrix_cuh - -#include -#include - -class MatrixView_half -{ -public: - const half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } - __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } -}; - -class MatrixView_half_rw -{ -public: - half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } - __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } - __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } -}; - -class MatrixView_q4_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x07) * 4; - return (data[row * width / 8 + column / 8] >> shift) & 0x0f; - } -}; - -class MatrixView_q4_column -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (row & 0x07) * 4; - return (data[row / 8 * width + column] >> shift) & 0x0f; - } - - __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } - __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } -}; - -// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu - -// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale - -__device__ __forceinline__ half2 dot_product_8 -( - const half2 acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half2 v_scale_2, - const uint32_t v_zero, // + 1 (!!) - const int count -) -{ - const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column); - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half2 result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half2 v_01 = __halves2half2(v_0, v_1); - half2 v_23 = __halves2half2(v_2, v_3); - half2 v_45 = __halves2half2(v_4, v_5); - half2 v_67 = __halves2half2(v_6, v_7); - -// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently) -// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff]; -// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff]; -// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ]; - - half2 tmp = __hmul2(*h_ptr++, v_01); - tmp = __hfma2(*h_ptr++, v_23, tmp); - tmp = __hfma2(*h_ptr++, v_45, tmp); - tmp = __hfma2(*h_ptr++, v_67, tmp); - result = __hfma2(v_scale_2, tmp, result); - } - - return result; -} - -__device__ __forceinline__ half dot_product_8_h -( - const half acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half v_scale, - const uint32_t v_zero, // + 1 (!!) - const int count -) -{ - const half* h_ptr = h_.item_ptr(h_row, h_column); - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half tmp = __hmul(*h_ptr++, v_0); - tmp = __hfma(*h_ptr++, v_1, tmp); - tmp = __hfma(*h_ptr++, v_2, tmp); - tmp = __hfma(*h_ptr++, v_3, tmp); - tmp = __hfma(*h_ptr++, v_4, tmp); - tmp = __hfma(*h_ptr++, v_5, tmp); - tmp = __hfma(*h_ptr++, v_6, tmp); - tmp = __hfma(*h_ptr++, v_7, tmp); - result = __hfma(v_scale, tmp, result); - } - - return result; -} - -// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map - -__device__ __forceinline__ half2 dot_product_8_x_map -( - const half2 acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half2 v_scale_2, - const uint32_t v_zero, // + 1 (!!) - const int count, - const uint32_t* x_map -) -{ - const half* h_ptr = h_.item_ptr(h_row, 0); - const uint32_t* x_map_ptr = x_map + h_column; - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half2 result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half2 v_01 = __halves2half2(v_0, v_1); - half2 v_23 = __halves2half2(v_2, v_3); - half2 v_45 = __halves2half2(v_4, v_5); - half2 v_67 = __halves2half2(v_6, v_7); - - half h_0 = h_ptr[*x_map_ptr++]; - half h_1 = h_ptr[*x_map_ptr++]; - half h_2 = h_ptr[*x_map_ptr++]; - half h_3 = h_ptr[*x_map_ptr++]; - half h_4 = h_ptr[*x_map_ptr++]; - half h_5 = h_ptr[*x_map_ptr++]; - half h_6 = h_ptr[*x_map_ptr++]; - half h_7 = h_ptr[*x_map_ptr++]; - - half2 h_01 = __halves2half2(h_0, h_1); - half2 h_23 = __halves2half2(h_2, h_3); - half2 h_45 = __halves2half2(h_4, h_5); - half2 h_67 = __halves2half2(h_6, h_7); - - half2 tmp = __hmul2(h_01, v_01); - tmp = __hfma2(h_23, v_23, tmp); - tmp = __hfma2(h_45, v_45, tmp); - tmp = __hfma2(h_67, v_67, tmp); - result = __hfma2(v_scale_2, tmp, result); - } - - return result; -} - -__device__ __forceinline__ half dot_product_8_x_map_h -( - const half acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half v_scale, - const uint32_t v_zero, // + 1 (!!) - const int count, - const uint32_t* x_map -) -{ - const half* h_ptr = h_.item_ptr(h_row, 0); - const uint32_t* x_map_ptr = x_map + h_column; - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half tmp = __hmul(h_ptr[*x_map_ptr++], v_0); - tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp); - result = __hfma(v_scale, tmp, result); - } - - return result; -} - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu deleted file mode 100644 index f47daeb0e..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu +++ /dev/null @@ -1,260 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include "q4_matmul.cuh" -#include "column_remap.cuh" -#include "util.cuh" -#include "matrix.cuh" -#include "cu_compat.cuh" -#include "cuda_buffers.cuh" -#if defined(USE_ROCM) -#include "hip_compat.cuh" -#endif - -const int THREADS_X = 32; // Block size and thread count along columns in w and out -const int THREADS_Y = 1; // Block size and thread count along rows in x and out - -typedef void (*fp_q4_matmul_kernel) -( - const half*, - const uint32_t*, - half*, - const half*, - const uint32_t*, - const int, - const int, - const int, - const int, - const int, - const uint32_t*, - bool -); - -template -__global__ void q4_matmul_kernel -( - const half* __restrict__ x, - const uint32_t* __restrict__ w, - half* __restrict__ out, - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int height, - const int dim, - const int width, - const int groupsize, - const int block_size_z, - const uint32_t* __restrict__ x_map, - bool no_zero -) -{ - // Start of block - - int x_column = block_size_z * blockIdx.z; - int x_column_end = min(dim, block_size_z * (blockIdx.z + 1)); - - int w_column = THREADS_X * blockIdx.x + threadIdx.x; - int x_row = THREADS_Y * blockIdx.y + threadIdx.y; - - int iterations = (x_column_end - x_column) / 8; - - // Views - - MatrixView_half x_(x, height, dim); - MatrixView_half w_scales_(w_scales, dim / groupsize, width); - MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width); - MatrixView_q4_column w_(w, dim, width); - MatrixView_half_rw out_(out, height, width); - - // Zero output - - if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0) - { - *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0; - __syncthreads(); - } - - // Loop over part of x row (and w column) - - half2 acc = {}; - half acc_h = {}; - - if constexpr (use_groupsize) - { - // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this - // could be slightly faster - - for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize) - { - if constexpr (use_half2) - { - half2 w_scale = w_scales_.item_half2half2(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); - else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); - } - else - { - half w_scale = w_scales_.item(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); - else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); - } - } - } - else - { - // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache - - for (int k = x_column; k < x_column + iterations * 8; k += 8) - { - if constexpr (use_half2) - { - int group = k / groupsize; - half2 w_scale = w_scales_.item_half2half2(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); - else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); - } - else - { - int group = k / groupsize; - half w_scale = w_scales_.item(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); - else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); - } - } - } - - // Add to block result - - if constexpr (use_half2) - { - half result = __hadd(__low2half(acc), __high2half(acc)); - atomicAdd(out_.item_ptr(x_row, w_column), result); - } - else - { - atomicAdd(out_.item_ptr(x_row, w_column), acc_h); - } -} - -fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map) -{ - // - if (tuningParams->matmul_no_half2) { - if (block_size_z % groupsize == 0) { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } else { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } - } else { - if (block_size_z % groupsize == 0) - { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } else { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } - } -}; - -// Compute y = x @ w - -void q4_matmul_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - const Q4Matrix* w, - half* out, - bool no_zero, - cudaStream_t alt_stream -) -{ - int height = x_height; - int dim = w->height; - int width = w->width; - - cudaSetDevice(w->device); - - uint32_t* x_map = w->cuda_x_map; - const half* x_mapped = x; - if (x_map && !tuningParams->matmul_fused_remap && !alt_stream) - { - CudaBuffers* buffers = get_buffers(w->device); - column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); - x_mapped = buffers->temp_state; - x_map = NULL; - } - - int block_size_z; - if (w->width == 4096) block_size_z = 384; // 7B - else if (w->width == 11008) block_size_z = 256; - else if (w->width == 5120) block_size_z = 384; // 13B - else if (w->width == 13824) block_size_z = 256; - else if (w->width == 6656) block_size_z = 256; // 33B - else if (w->width == 17920) block_size_z = 128; - else block_size_z = 256; - - //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half)); - - dim3 threads(THREADS_X, THREADS_Y, 1); - - dim3 blocks - ( - (width + threads.x - 1) / threads.x, - (height + threads.y - 1) / threads.y, - (dim + block_size_z - 1) / block_size_z - ); - - fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); - - kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); -} - -void q4_matmul_recons_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - Q4Matrix* w, - half* out, - const cublasHandle_t handle, - bool no_zero -) -{ - int height = x_height; - int dim = w->height; - int width = w->width; - - cudaSetDevice(w->device); - CudaBuffers* buffers = get_buffers(w->device); - - const half* x_mapped = x; - if (w->cuda_x_map) - { - TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small"); - column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); - x_mapped = buffers->temp_state; - } - - w->reconstruct(buffers->temp_dq); - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700 - const float alpha = 1.0f; - const float beta = no_zero ? 1.0f : 0.0f; - cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width, - x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width); -#else - const half alpha = __float2half(1.0f); - const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f); - cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width); -#endif -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh deleted file mode 100644 index 09f3e1a63..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh +++ /dev/null @@ -1,43 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _q4_matmul_cuh -#define _q4_matmul_cuh - -#include -#include -#include -#include -#include - -#include "q4_matrix.cuh" -#include "tuning.h" - -// Workaround for hipify_python using rocblas instead of hipblas. -#if defined(USE_ROCM) -#include -#define rocblas_handle hipblasHandle_t -#endif - -void q4_matmul_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - const Q4Matrix* w, - half* out, - bool no_zero = false, - cudaStream_t alt_stream = NULL -); - -void q4_matmul_recons_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - Q4Matrix* w, - half* out, - const cublasHandle_t handle, - bool no_zero = false -); - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu deleted file mode 100644 index 9c61143f5..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu +++ /dev/null @@ -1,225 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include "q4_matrix.cuh" -#include -#include "util.cuh" -#include "matrix.cuh" - -using namespace std; - -const int UNSHUF_BLOCKSIZE_X = 64; - -const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column -const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows - -vector g_q4_matrices; - -void g_q4_keep_matrix(Q4Matrix* m) -{ - g_q4_matrices.push_back(m); -} - -void g_q4_free_matrices() -{ - for (const auto& m : g_q4_matrices) delete m; - g_q4_matrices.clear(); -} - -Q4Matrix::Q4Matrix -( - const int _height, - const int _width, - const int _groups, - - uint32_t* _qweight, - uint32_t* _qzeros, - half* _scales, - uint32_t* _g_idx, - - const int _device -) : - height(_height), - width(_width), - groups(_groups), - device(_device) -{ - cudaSetDevice(device); - - cuda_qweight = _qweight; - cuda_qzeros = _qzeros; - cuda_scales = _scales; - - groupsize = height / groups; - - if (_g_idx) make_sequential(_g_idx); -} - -Q4Matrix::~Q4Matrix() -{ -} - -// Make sequential - -__global__ void make_sequential_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const uint32_t* __restrict__ x_map, - const int w_height, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - - int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - - int w_new2_row = blockIdx.y; - - int x_map_idx = w_new2_row << 3; - - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 8; i++) - { - int source_row = x_map[x_map_idx++]; - - int w2_row = source_row >> 3; - int w2_subrow = source_row & 0x07; - int w2_row_shift = w2_subrow << 2; - int wnew2_row_shift = i << 2; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x0000000f0000000f; - src <<= wnew2_row_shift; - dst |= src; - } - - w_new2[w_new2_row * w2_stride + w2_column] = dst; -} - -void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx) -{ - uint32_t* cuda_new_qweight = NULL; - cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); - cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch - - uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); - uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); - uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); - - // Group histogram - - for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; - - // Group map - - for (int i = 0, acc = 0; i < groups; i++) - { - short tmp = cpu_g_idx_map[i]; - cpu_g_idx_map[i] = acc; - acc += tmp; - } - - // X map (inverse) - - for (int row = 0; row < height; row++) - { - uint32_t target_group = cpu_g_idx[row]; - uint32_t target_row = cpu_g_idx_map[target_group]; - cpu_g_idx_map[target_group]++; - cpu_x_map_inv[row] = target_row; - } - - // X map - - for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; - - // Move to CUDA - - cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice); - - // Rearrange rows in w - - dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1); - dim3 blocks - ( - (width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2), - height / 8, - 1 - ); - - make_sequential_kernel<<>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width); - - // Replace qweights - - cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); - - // Cleanup - - cudaDeviceSynchronize(); - cudaFree(cuda_new_qweight); - free(cpu_g_idx_map); - free(cpu_x_map); - free(cpu_x_map_inv); -} - -__global__ void reconstruct_kernel -( - const uint32_t* __restrict__ w, - half* __restrict__ out, // (y) - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int height, - const int width, - const int groupsize -) -{ - // Start of block - - int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x; - int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8; - if (column >= width) return; - - // Views - - MatrixView_q4_column w_(w, height, width); - MatrixView_half_rw out_(out, height, width); - MatrixView_half w_scales_(w_scales, height / groupsize, width); - MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width); - - // Groupsize version - - int group = row / groupsize; - - half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; - - uint32_t w_read = w_.item_uint32_t(row, column); - half* out_ptr = out_.item_ptr(row, column); - - #pragma unroll - for (int s = 0; s < 32; s += 4) - { - half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); - *out_ptr = w_item; out_ptr += out_.width; - } -} - -void Q4Matrix::reconstruct(half* out) -{ - dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1); - - dim3 blocks - ( - (width + threads.x - 1) / threads.x, - (height / 8 + threads.y - 1) / threads.y, - 1 - ); - - reconstruct_kernel<<>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize); -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh deleted file mode 100644 index 50cb72a41..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh +++ /dev/null @@ -1,53 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _q4_matrix_cuh -#define _q4_matrix_cuh - -#include -#include -#include - -class Q4Matrix -{ -public: - - int device; - - int height; - int width; - int groups; - int groupsize; - - uint32_t* cuda_qweight = NULL; - uint32_t* cuda_qzeros = NULL; - half* cuda_scales = NULL; - uint32_t* cuda_x_map = NULL; - - Q4Matrix - ( - const int _height, - const int _width, - const int _groups, - - uint32_t* _qweight, - uint32_t* _qzeros, - half* _scales, - uint32_t* _g_idx, - - const int _device - ); - - ~Q4Matrix(); - - void reconstruct(half* out); - -private: - - void make_sequential(const uint32_t* cpu_g_idx); - -}; - -void g_q4_keep_matrix(Q4Matrix* m); -void g_q4_free_matrices(); - -#endif \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/gptq/tuning.h b/colossalai/kernel/cuda_native/csrc/gptq/tuning.h deleted file mode 100644 index e413b8a96..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/tuning.h +++ /dev/null @@ -1,12 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _tuning_h -#define _tuning_h - -struct ExLlamaTuning { - int matmul_recons_thd; - bool matmul_fused_remap; - bool matmul_no_half2; -}; - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/util.cuh b/colossalai/kernel/cuda_native/csrc/gptq/util.cuh deleted file mode 100644 index 7b3975732..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/util.cuh +++ /dev/null @@ -1,33 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _util_cuh -#define _util_cuh - -#include -#include -#include -#include - -#if defined(USE_ROCM) -#define cudaUnspecified hipErrorUnknown -#else -#define cudaUnspecified cudaErrorApiFailureBase -#endif - -// React to failure on return code != cudaSuccess - -#define _cuda_check(fn) \ -do { \ - {_cuda_err = fn;} \ - if (_cuda_err != cudaSuccess) goto _cuda_fail; \ -} while(false) - -// React to failure on return code == 0 - -#define _alloc_check(fn) \ -do { \ - if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \ - else _cuda_err = cudaSuccess; \ -} while(false) - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu b/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu deleted file mode 100644 index 58d26235a..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu +++ /dev/null @@ -1,191 +0,0 @@ -#include "block_reduce.h" -#include "cuda_util.h" -#include "kernels.h" -#include "ls_cub.cuh" - -ls::cub::CachingDeviceAllocator g_allocator(true); - -template -__global__ void ls_cross_entropy_fw_kernel( - const T *__restrict__ inputs, const int *__restrict__ targets, - float *__restrict__ outputs, float *__restrict__ nll_loss_outputs, - const int padding_idx, const float epsilon, const int vocab_size) { - /* step1: compute each thread's max_logit and sum_exp_logit, store in - * max_input, sum_exp_logit */ - const int block_start = blockIdx.x * vocab_size; - const int left_idx = block_start + threadIdx.x; - const int right_idx = (blockIdx.x + 1) * vocab_size; - float max_input[1] = {REDUCE_FLOAT_INF_NEG}; - float sum_logits[2] = {0.f, 0.f}; // logit and logit exp - int target_tid = targets[blockIdx.x]; - - if (target_tid == padding_idx) { - if (threadIdx.x == 0) { - nll_loss_outputs[blockIdx.x] = 0.f; - outputs[blockIdx.x] = 0.f; - } - return; - } - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - max_input[0] = fmaxf(max_input[0], static_cast(inputs[i])); - } - blockReduce(max_input); - __shared__ float s_max_input; - if (threadIdx.x == 0) { - s_max_input = max_input[0]; - } - __syncthreads(); - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - float logit = static_cast(inputs[i]) - s_max_input; - sum_logits[0] += logit; - sum_logits[1] += expf(logit); - } - - blockReduce(sum_logits); - __shared__ float s_sum_logit; - __shared__ float s_sum_exp; - if (threadIdx.x == 0) { - s_sum_logit = sum_logits[0]; - s_sum_exp = sum_logits[1]; - } - __syncthreads(); - - float eps_i = epsilon / (vocab_size - 1); - if (threadIdx.x == 0) { - // neg_log_prob = log(sum(exp(x - x_max))) - (x - x_max) - float nll_loss = logf(s_sum_exp) - - static_cast(inputs[block_start + target_tid]) + - s_max_input; - nll_loss_outputs[blockIdx.x] = nll_loss; - float sum_nll_loss = vocab_size * logf(s_sum_exp) - s_sum_logit; - outputs[blockIdx.x] = - (1.f - epsilon - eps_i) * nll_loss + eps_i * sum_nll_loss; - } -} - -template -__global__ void ls_cross_entropy_bw_kernel( - const float *__restrict__ grad_outputs, const T *__restrict__ inputs, - const int *__restrict__ targets, T *__restrict__ grad_inputs, - const int padding_idx, const float epsilon, const int vocab_size) { - /* step1: compute each thread's max_logit and sum_exp_logit, store in - * max_input, sum_exp_logit */ - const int block_start = blockIdx.x * vocab_size; - const int left_idx = block_start + threadIdx.x; - const int right_idx = (blockIdx.x + 1) * vocab_size; - float max_input[1] = {REDUCE_FLOAT_INF_NEG}; - float sum_logits[1] = {0.f}; - const float grad_out = static_cast(grad_outputs[0]); - int target_tid = targets[blockIdx.x]; - - if (target_tid == padding_idx) { - for (int i = left_idx; i < right_idx; i += blockDim.x) { - grad_inputs[i] = 0.f; - } - return; - } - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - max_input[0] = fmaxf(max_input[0], static_cast(inputs[i])); - } - blockReduce(max_input); - __shared__ float s_max_input; - if (threadIdx.x == 0) { - s_max_input = max_input[0]; - } - __syncthreads(); - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - float logit = static_cast(inputs[i]) - s_max_input; - sum_logits[0] += expf(logit); - } - - blockReduce(sum_logits); - __shared__ float s_sum_exp; - if (threadIdx.x == 0) { - s_sum_exp = sum_logits[0]; - } - __syncthreads(); - - float eps_i = epsilon / (vocab_size - 1); - float nll_weight = 1.0 - epsilon - eps_i; - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - float prob = expf(static_cast(inputs[i]) - s_max_input) / s_sum_exp; - float grad = 0; - grad += (vocab_size * prob - 1) * eps_i; - grad += prob * nll_weight; - if ((i - block_start) == target_tid) { - grad -= nll_weight; - } - grad_inputs[i] = grad_out * grad; - } -} - -template -void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr, - float *outputs_ptr, float *nll_loss_ptr, - float *loss_buffer, const int padding_idx, - const float epsilon, const int batch_size, - const int seq_len, const int vocab_size, - cudaStream_t stream) { - int grid_dim = batch_size * seq_len; - float *nll_loss_buffer = loss_buffer + grid_dim; - ls_cross_entropy_fw_kernel<<>>( - inputs_ptr, targets_ptr, loss_buffer, nll_loss_buffer, padding_idx, - epsilon, vocab_size); - - int num_items = grid_dim; - void *d_temp_storage = NULL; - size_t temp_storage_bytes = 0; - CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, - loss_buffer, outputs_ptr, - num_items, stream)); - CHECK_GPU_ERROR( - g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes)); - CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, - loss_buffer, outputs_ptr, - num_items, stream)); - CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, - nll_loss_buffer, nll_loss_ptr, - num_items, stream)); - CHECK_GPU_ERROR(g_allocator.DeviceFree(d_temp_storage)); -} - -template void launch_cross_entropy_fw( - const float *inputs_ptr, const int *targets_ptr, float *outputs_ptr, - float *nll_loss_ptr, float *loss_buffer, const int padding_idx, - const float epsilon, const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); - -template void launch_cross_entropy_fw<__half>( - const __half *inputs_ptr, const int *targets_ptr, float *outputs_ptr, - float *nll_loss_ptr, float *loss_buffer, const int padding_idx, - const float epsilon, const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); - -template -void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr, - const int *targets_ptr, T *grad_inputs_ptr, - const int padding_idx, const float epsilon, - const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream) { - int grid_dim = batch_size * seq_len; - ls_cross_entropy_bw_kernel<<>>( - grad_outputs_ptr, inputs_ptr, targets_ptr, grad_inputs_ptr, padding_idx, - epsilon, vocab_size); -} - -template void launch_cross_entropy_bw( - const float *grad_outputs_ptr, const float *inputs_ptr, - const int *targets_ptr, float *grad_inputs_ptr, const int padding_idx, - const float epsilon, const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); - -template void launch_cross_entropy_bw<__half>( - const float *grad_outputs_ptr, const __half *inputs_ptr, - const int *targets_ptr, __half *grad_inputs_ptr, const int padding_idx, - const float epsilon, const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu b/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu deleted file mode 100644 index 09f34763f..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2021 The LightSeq Team - Copyright Microsoft DeepSpeed - This file is adapted from Microsoft DeepSpeed - Licensed under the MIT License. -*/ -#include "cublas_wrappers.h" - -int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float *alpha, const float *beta, const float *A, - const float *B, float *C, cublasGemmAlgo_t algo) { - cublasStatus_t status = - cublasGemmEx(handle, transa, transb, m, n, k, (const void *)alpha, - (const void *)A, CUDA_R_32F, (transa == CUBLAS_OP_N) ? m : k, - (const void *)B, CUDA_R_32F, (transb == CUBLAS_OP_N) ? k : n, - (const void *)beta, C, CUDA_R_32F, m, CUDA_R_32F, algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", - m, n, k, (int)status); - return EXIT_FAILURE; - } - return 0; -} - -int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float *alpha, const float *beta, const __half *A, - const __half *B, __half *C, cublasGemmAlgo_t algo) { - cublasStatus_t status = cublasGemmEx( - handle, transa, transb, m, n, k, (const void *)alpha, (const void *)A, - CUDA_R_16F, (transa == CUBLAS_OP_N) ? m : k, (const void *)B, CUDA_R_16F, - (transb == CUBLAS_OP_N) ? k : n, (const void *)beta, (void *)C, - CUDA_R_16F, m, CUDA_R_32F, algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", - m, n, k, (int)status); - return EXIT_FAILURE; - } - return 0; -} - -int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, - const float *alpha, const float *beta, - const float *A, const float *B, float *C, - cublasOperation_t op_A, cublasOperation_t op_B, - int stride_A, int stride_B, int stride_C, - int batch, cublasGemmAlgo_t algo) { - cublasStatus_t status = cublasGemmStridedBatchedEx( - handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_32F, - (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_32F, - (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_32F, m, stride_C, - batch, CUDA_R_32F, algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, " - "error: %d) \n", - batch, m, n, k, (int)status); - return EXIT_FAILURE; - } - return 0; -} - -int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, - const float *alpha, const float *beta, - const __half *A, const __half *B, __half *C, - cublasOperation_t op_A, cublasOperation_t op_B, - int stride_A, int stride_B, int stride_C, - int batch, cublasGemmAlgo_t algo) { - cublasStatus_t status = cublasGemmStridedBatchedEx( - handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_16F, - (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_16F, - (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_16F, m, stride_C, - batch, CUDA_R_32F, algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", - m, n, k, (int)status); - return EXIT_FAILURE; - } - - return 0; -} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu deleted file mode 100644 index e5ac17308..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu +++ /dev/null @@ -1,169 +0,0 @@ -#include -#include -#include -#include "cuda_util.h" - -/* GPU function guard */ -std::string _cudaGetErrorString(cudaError_t error) { - return cudaGetErrorString(error); -} - -std::string _cudaGetErrorString(cublasStatus_t error) { - switch (error) { - case CUBLAS_STATUS_SUCCESS: - return "CUBLAS_STATUS_SUCCESS"; - - case CUBLAS_STATUS_NOT_INITIALIZED: - return "CUBLAS_STATUS_NOT_INITIALIZED"; - - case CUBLAS_STATUS_ALLOC_FAILED: - return "CUBLAS_STATUS_ALLOC_FAILED"; - - case CUBLAS_STATUS_INVALID_VALUE: - return "CUBLAS_STATUS_INVALID_VALUE"; - - case CUBLAS_STATUS_ARCH_MISMATCH: - return "CUBLAS_STATUS_ARCH_MISMATCH"; - - case CUBLAS_STATUS_MAPPING_ERROR: - return "CUBLAS_STATUS_MAPPING_ERROR"; - - case CUBLAS_STATUS_EXECUTION_FAILED: - return "CUBLAS_STATUS_EXECUTION_FAILED"; - - case CUBLAS_STATUS_INTERNAL_ERROR: - return "CUBLAS_STATUS_INTERNAL_ERROR"; - - case CUBLAS_STATUS_NOT_SUPPORTED: - return "CUBLAS_STATUS_NOT_SUPPORTED"; - - case CUBLAS_STATUS_LICENSE_ERROR: - return "CUBLAS_STATUS_LICENSE_ERROR"; - } - return "CUBLAS_UNKNOW"; -} - -template -void check_gpu_error(T result, char const *const func, const char *const file, - int const line) { - if (result) { - throw std::runtime_error(std::string("[CUDA][ERROR] ") + +file + "(" + - std::to_string(line) + - "): " + (_cudaGetErrorString(result)) + "\n"); - } -} - -template void check_gpu_error(cudaError_t result, - char const *const func, - const char *const file, - int const line); -template void check_gpu_error(cublasStatus_t result, - char const *const func, - const char *const file, - int const line); - -template -void print_vec(const T *outv, std::string outn, int num_output_ele) { - std::cout << outn << ": "; - std::vector hout(num_output_ele, (T)0); - cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(T), - cudaMemcpyDeviceToHost); - for (int i = 0; i < num_output_ele; i++) { - std::cout << hout[i] << ", "; - } - std::cout << std::endl; -} - -template <> -void print_vec<__half>(const __half *outv, std::string outn, - int num_output_ele) { - std::cout << outn << ": "; - std::vector<__half> hout(num_output_ele, (__half)0.f); - cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(__half), - cudaMemcpyDeviceToHost); - for (int i = 0; i < num_output_ele; i++) { - std::cout << __half2float(hout[i]) << ", "; - } - std::cout << std::endl; -} - -template void print_vec(const float *outv, std::string outn, - int num_output_ele); - -template void print_vec(const int *outv, std::string outn, - int num_output_ele); - -template void print_vec<__half>(const __half *outv, std::string outn, - int num_output_ele); - -template -T *cuda_malloc(size_t ele_num) { - size_t byte_size = ele_num * sizeof(T); - T *pdata = nullptr; - CHECK_GPU_ERROR(cudaMalloc((void **)&pdata, byte_size)); - return pdata; -} - -template float *cuda_malloc(size_t ele_num); - -template __half *cuda_malloc<__half>(size_t ele_num); - -template uint8_t *cuda_malloc(size_t ele_num); - -void cuda_free(void *pdata) { - if (pdata != nullptr) { - cudaFree(pdata); - } -} - -template -struct _isnan { - __device__ bool operator()(T a) const { return isnan(a); } -}; - -template <> -struct _isnan<__half> { - __device__ bool operator()(const __half a) const { return __hisnan(a); } -}; - -template -struct _isinf { - __device__ bool operator()(T a) const { return isinf(a); } -}; - -template <> -struct _isinf<__half> { - __device__ bool operator()(const __half a) const { return __hisinf(a); } -}; - -template -void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf, - std::string file, int line, cudaStream_t stream) { - // check_nan_inf = 0 for checking nan - // check_nan_inf = 1 for checking inf - bool res = false; - std::string msg = file + "(" + std::to_string(line) + "): "; - if (check_nan_inf) { - msg += "nan."; - res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr, - data_ptr + dsize, _isnan(), false, - thrust::logical_or()); - } else { - msg += "inf."; - res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr, - data_ptr + dsize, _isinf(), false, - thrust::logical_or()); - } - if (res) { - throw std::runtime_error(msg); - } - std::cout << msg << " [check pass]." << std::endl; -} - -template void check_nan_inf(const float *data_ptr, int dsize, - bool check_nan_inf, std::string file, - int line, cudaStream_t stream); - -template void check_nan_inf<__half>(const __half *data_ptr, int dsize, - bool check_nan_inf, std::string file, - int line, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu deleted file mode 100644 index ce0b017f1..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu +++ /dev/null @@ -1,1002 +0,0 @@ -#include -#include - -#include "kernels.h" - -#include - - -namespace cg = cooperative_groups; - -curandStatePhilox4_32_10_t *curandstate; - -/** - * @brief element-wise activation function on device, like Relu, Gelu - * - * @tparam enum class ActivationType, kRelu, kGelu - * @tparam input type - * @param any shape of float and __half2 - * @return same shape and type with input - */ -template -__forceinline__ __device__ T activation_kernel(T x); - -template <> -__device__ float activation_kernel(float x) { - float cdf = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); - return x * cdf; -} - -template <> -__device__ __half2 -activation_kernel(__half2 val) { - __half2 val_pow3 = __hmul2(val, __hmul2(val, val)); - float2 tmp_pow = __half22float2(val_pow3); - float2 tmp = __half22float2(val); - - tmp.x = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); - tmp.y = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); - return __hmul2(val, __float22half2_rn(tmp)); -} - -template <> -__device__ float activation_kernel(float x) { - return fmaxf(x, 0); -} - -template <> -__device__ __half2 -activation_kernel(__half2 x) { - return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)), - fmaxf(0.f, __half2float(x.y))); -} - -/** - * @brief element-wise activation backward function on device - * - * @tparam enum class ActivationType - * @tparam input type - * @param any shape of float and __half2 - * @return same shape of input - */ -template -__forceinline__ __device__ T activation_bwd_kernel(T grad, T x); - -template <> -__device__ float activation_bwd_kernel(float grad, - float x) { - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - - float x2mul = x * x * mul_param; - float tan_h = tanhf(sqrt_param * (x + x * x2mul)); - float dg1 = 0.5f * (1.0f + tan_h); - float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); - float dg3 = dg2 * 3 * x2mul; - return grad * (dg1 + dg2 + dg3); -} - -template <> -__device__ __half activation_bwd_kernel( - __half grad, __half x_half) { - float x = __half2float(x_half); - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - - float x2mul = x * x * mul_param; - float tan_h = tanhf(sqrt_param * (x + x * x2mul)); - float dg1 = 0.5f * (1.0f + tan_h); - float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); - float dg3 = dg2 * 3 * x2mul; - return grad * __float2half(dg1 + dg2 + dg3); -} - -template <> -__device__ float activation_bwd_kernel(float grad, - float x) { - return x > 0.f ? grad : 0.f; -} - -template <> -__device__ __half -activation_bwd_kernel(__half grad, __half x) { - const __half half_zero = __float2half(0.f); - return x > half_zero ? grad : half_zero; -} - -template <> -__device__ __half2 activation_bwd_kernel( - __half2 grad2, __half2 x_half2) { - const __half half_zero = __float2half(0.f); - return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero, - x_half2.y > half_zero ? grad2.y : half_zero); -} - -/** - * @brief init curand states in global memory - * - * @thread grid_dim * block*dim to suuport any size of states - * @param state persistant curand states - * @param seed seed to init states - * @return void - */ -__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state, - int seed) { - /* Each thread gets same seed, a different sequence - number, no offset */ - int id = threadIdx.x + blockIdx.x * blockDim.x; - curand_init(seed, id, 0, &state[id]); -} - -void launch_curand_init(int total_count, int dim, cudaStream_t stream) { - cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t)); - int grid_dim = total_count >> 9; - curand_init_kernel<<>>( - curandstate, std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); -} - -/** - * @brief element-wise dropout, store dropped position in mask, it's not - * in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param out any size of float and __half - * @param in same with out - * @param mask uint8 type, same size with out - * @param seed seed to curand - * @return void - */ -__global__ void ls_dropout_kernel(const int total_count, const float ratio, - float *__restrict__ out, - const float *__restrict__ in, - uint8_t *__restrict__ mask, const int seed) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - - float4 input4 = data4[i]; - float4 res4; - res4.x = input4.x * scale * m[0]; - res4.y = input4.y * scale * m[1]; - res4.z = input4.z * scale * m[2]; - res4.w = input4.w * scale * m[3]; - out4[i] = res4; -} - -__global__ void ls_dropout_kernel(const int total_count, const float ratio, - __half *__restrict__ out, - const __half *__restrict__ in, - uint8_t *__restrict__ mask, const int seed) { - const float scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = (uint8_t)(rand.x > ratio); - m[5] = (uint8_t)(rand.y > ratio); - m[6] = (uint8_t)(rand.z > ratio); - m[7] = (uint8_t)(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = *m8; - - float4 val_float4 = vals_float4[i]; - float4 out_float4; - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); - __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); - __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); - __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); - out_half2[0] = __hmul2(val_half2[0], scale_mask_1); - out_half2[1] = __hmul2(val_half2[1], scale_mask_2); - out_half2[2] = __hmul2(val_half2[2], scale_mask_3); - out_half2[3] = __hmul2(val_half2[3], scale_mask_4); - outs_float4[i] = out_float4; -} - -/** - * @brief element-wise dropout backward with dropout mask, it's - * not in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param in any size of float and __half - * @param mask uint8 type, same size with in - * @return void - */ -__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, - float *out, const float *in, - const uint8_t *__restrict__ mask) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *in4 = reinterpret_cast(in); - const uint32_t *mask4 = reinterpret_cast(mask); - - uint32_t *m4 = reinterpret_cast(m); - m4[0] = mask4[i]; - - float4 input4 = in4[i]; - float4 res4; - res4.x = input4.x * scale * static_cast(m[0]); - res4.y = input4.y * scale * static_cast(m[1]); - res4.z = input4.z * scale * static_cast(m[2]); - res4.w = input4.w * scale * static_cast(m[3]); - out4[i] = res4; -} - -__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, - __half *out, const __half *in, - const uint8_t *__restrict__ mask) { - const __half scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - float4 *out4 = reinterpret_cast(out); - const float4 *vals_float4 = reinterpret_cast(in); - const uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - uint64_t *m8 = reinterpret_cast(m); - m8[0] = mask8[i]; - - float4 val_float4 = vals_float4[i]; - float4 out_float4; - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - __half2 scale_mask_1 = - __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); - __half2 scale_mask_2 = - __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); - __half2 scale_mask_3 = - __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); - __half2 scale_mask_4 = - __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); - out_half2[0] = __hmul2(val_half2[0], scale_mask_1); - out_half2[1] = __hmul2(val_half2[1], scale_mask_2); - out_half2[2] = __hmul2(val_half2[2], scale_mask_3); - out_half2[3] = __hmul2(val_half2[3], scale_mask_4); - out4[i] = out_float4; -} - -template <> -void launch_ls_dropout(float *out, const float *vals, uint8_t *mask, - int total_count, float ratio, cudaStream_t stream, - bool backward) { - int grid_dim = total_count >> 12; - if (!backward) { - ls_dropout_kernel<<>>( - total_count, ratio, out, vals, mask, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); - } else { - ls_dropout_bwd_kernel<<>>(total_count, ratio, - out, vals, mask); - } -} - -template <> -void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask, - int total_count, float ratio, - cudaStream_t stream, bool backward) { - int grid_dim = total_count >> 13; - if (!backward) { - ls_dropout_kernel<<>>( - total_count, ratio, out, vals, mask, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); - } else { - ls_dropout_bwd_kernel<<>>(total_count, ratio, - out, vals, mask); - } -} - -/** - * @brief fused bias, dropout, and residual at the end of Attention and FFN, - * store dropped position in mask, it's not in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param out [batch_size, seq_len, hidden_size], float and __half - * @param in [batch_size, seq_len, hidden_size], float and __half - * @param mask [batch_size, seq_len, hidden_size], uint8 type - * @param bias [hidden_size], ffn bias - * @param residual [batch_size, seq_len, hidden_size], float and __half - * @param seed seed to curand - * @param hidden_size hidden size - * @return void - */ -__global__ void ls_dropout_res_bias_kernel( - const int total_count, const float ratio, float *__restrict__ out, - const float *__restrict__ in, uint8_t *__restrict__ mask, - const float *__restrict__ bias, const float *__restrict__ residual, - const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - const float4 *residual4 = reinterpret_cast(residual); - const float4 *bias4 = reinterpret_cast(bias); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = static_cast(rand.x > ratio); - m[1] = static_cast(rand.y > ratio); - m[2] = static_cast(rand.z > ratio); - m[3] = static_cast(rand.w > ratio); - - int bias_i = i % (hidden_size >> 2); - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - const float4 input4 = data4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - const float4 res4 = residual4[i]; - float4 output4; - - output4.x = (input4.x + b4.x) * scale * m[0] + res4.x; - output4.y = (input4.y + b4.y) * scale * m[1] + res4.y; - output4.z = (input4.z + b4.z) * scale * m[2] + res4.z; - output4.w = (input4.w + b4.w) * scale * m[3] + res4.w; - - out4[i] = output4; -} - -__global__ void ls_dropout_res_bias_kernel( - const int total_count, const float ratio, __half *__restrict__ out, - const __half *__restrict__ in, uint8_t *__restrict__ mask, - const __half *__restrict__ bias, const __half *__restrict__ residual, - const int seed, const int hidden_size) { - const __half scale = 1. / (1. - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - const float4 *residual4 = reinterpret_cast(residual); - const float4 *bias4 = reinterpret_cast(bias); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = static_cast(rand.x > ratio); - m[1] = static_cast(rand.y > ratio); - m[2] = static_cast(rand.z > ratio); - m[3] = static_cast(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = static_cast(rand.x > ratio); - m[5] = static_cast(rand.y > ratio); - m[6] = static_cast(rand.z > ratio); - m[7] = static_cast(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = m8[0]; - - int bias_i = i % (hidden_size >> 3); - float4 val_float4 = vals_float4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - const float4 res4 = residual4[i]; - float4 out_float4; - - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - const __half2 *b_half2 = reinterpret_cast(&b4); - const __half2 *res_half2 = reinterpret_cast(&res4); - __half2 scale_mask_1 = - __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); - __half2 scale_mask_2 = - __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); - __half2 scale_mask_3 = - __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); - __half2 scale_mask_4 = - __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); - out_half2[0] = - __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]); - out_half2[1] = - __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]); - out_half2[2] = - __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]); - out_half2[3] = - __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]); - outs_float4[i] = out_float4; -} - -template <> -void launch_ls_dropout_res_bias(float *out, const float *vals, - uint8_t *mask, const float *bias, - const float *residual, int total_count, - int dim, float ratio, - cudaStream_t stream) { - int grid_dim = total_count >> 12; - ls_dropout_res_bias_kernel<<>>( - total_count, ratio, out, vals, mask, bias, residual, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals, - uint8_t *mask, const __half *bias, - const __half *residual, int total_count, - int dim, float ratio, - cudaStream_t stream) { - int grid_dim = total_count >> 13; - ls_dropout_res_bias_kernel<<>>( - total_count, ratio, out, vals, mask, bias, residual, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -/** - * @brief fused bias and dropout backward at the end of Attention and FFN - * - * @thread - * gridDim.x = hidden_size / 8 - * blockDim.x = 8 - * blockDim.y = 1024 / 8 = 128 - * - * @param row_size batch_size * seq_len - * @param ratio dropout ratio - * @param in_grad [batch_size, seq_len, hidden_size], input grad - * @param bias_grad [hidden_size], bias grad - * @param out_grad [batch_size, seq_len, hidden_size], output grad - * @param mask [batch_size, seq_len, hidden_size], dropout mask - * @param hidden_size - * @return void - */ -__global__ void ls_dropout_bias_bwd_kernel( - const int row_size, const float ratio, float *__restrict__ in_grad, - float *__restrict__ bias_grad, const float *__restrict__ out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - // every block generate 8 bias result - __shared__ float tile[8][129]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); - int stride = hidden_size * 128; - float local_sum = 0; - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - for (int r = threadIdx.y; r < row_size; r += 128) { - float val = out_grad[idx]; - val *= scale * static_cast(mask[idx]); - local_sum += val; - in_grad[idx] = val; - idx += stride; - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - - float sum = 0; - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 7; - int y = tid & (127); - if (y < 32) { -#pragma unroll - for (int i = 0; i < 4; i++) { - sum += tile[x][y + i * 32]; - } - } - __syncthreads(); - - for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i); - - if (y == 0) tile[0][x] = sum; - __syncthreads(); - - if (threadIdx.x < 8) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); - bias_grad[pos] = tile[0][threadIdx.x]; - } -} - -__global__ void ls_dropout_bias_bwd_kernel( - const int row_size, const float ratio, __half *__restrict__ in_grad, - __half *__restrict__ bias_grad, const __half *__restrict__ out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); - __shared__ __half2 tile[8][129]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); - const __half2 *out_grad2 = reinterpret_cast(out_grad); - __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); - int stride = hidden_size * 128; - __half2 local_sum = __float2half2_rn(0.f); - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - for (int r = threadIdx.y; r < row_size; r += 128) { - __half2 val = out_grad2[idx]; - __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); - val *= scale * m2; - local_sum += val; - in_grad2[idx] = val; - idx += stride; - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - - __half2 sum = __float2half2_rn(0.f); - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 7; - int y = tid & (127); - if (y < 32) { -#pragma unroll - for (int i = 0; i < 4; i++) { - sum += tile[x][y + i * 32]; - } - } - __syncthreads(); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (y == 0) tile[0][x] = sum; - __syncthreads(); - - if (threadIdx.x < 8) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); - bias_grad2[pos] = tile[0][threadIdx.x]; - } -} - -template -void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream) { - dim3 grid_dim((dim - 1) / 8 + 1); - dim3 block_dim(8, 128); - ls_dropout_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); -} - -template <> -void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad, - const __half *out_grad, const uint8_t *mask, - int row_size, int dim, float ratio, - cudaStream_t stream) { - dim >>= 1; - dim3 grid_dim((dim - 1) / 8 + 1); - dim3 block_dim(8, 128); - ls_dropout_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); -} - -template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad, - const float *out_grad, - const uint8_t *mask, int row_size, - int dim, float ratio, - cudaStream_t stream); - -/** - * @brief fused bias, activation, and dropout at the end of first ffn - * - * @thread - * gridDim.x = hidden_size / 8 - * blockDim.x = 8 - * blockDim.y = 1024 / 8 = 128 - * - * @tparam act_type activation function, like kRelu, kGelu - * @param total_count total elements - * @param ratio drop ratio - * @param out [batch_size, seq_len, hidden_size], float and __half - * @param in [batch_size, seq_len, hidden_size], float and __half - * @param mask [batch_size, seq_len, hidden_size], uint8 type - * @param bias [hidden_size], ffn bias - * @param seed seed to curand - * @param hidden_size - * @return void - */ -template -__global__ void ls_dropout_act_bias_kernel( - const int total_count, const float ratio, float *__restrict__ out, - const float *__restrict__ in, uint8_t *__restrict__ mask, - const float *__restrict__ bias, const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - const float4 *bias4 = reinterpret_cast(bias); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - int bias_i = i % (hidden_size >> 2); - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - const float4 input4 = data4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - float4 output4; - - output4.x = - activation_kernel(input4.x + b4.x) * scale * m[0]; - output4.y = - activation_kernel(input4.y + b4.y) * scale * m[1]; - output4.z = - activation_kernel(input4.z + b4.z) * scale * m[2]; - output4.w = - activation_kernel(input4.w + b4.w) * scale * m[3]; - - out4[i] = output4; -} - -template -__global__ void ls_dropout_act_bias_kernel( - const int total_count, const float ratio, __half *__restrict__ out, - const __half *__restrict__ in, uint8_t *__restrict__ mask, - const __half *__restrict__ bias, const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - const float4 *bias4 = reinterpret_cast(bias); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = (uint8_t)(rand.x > ratio); - m[5] = (uint8_t)(rand.y > ratio); - m[6] = (uint8_t)(rand.z > ratio); - m[7] = (uint8_t)(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = *m8; - - int bias_i = i % (hidden_size >> 3); - float4 val_float4 = vals_float4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - float4 out_float4; - - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - const __half2 *b_half2 = reinterpret_cast(&b4); - - __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); - __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); - __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); - __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); - out_half2[0] = __hmul2( - activation_kernel(__hadd2(val_half2[0], b_half2[0])), - scale_mask_1); - out_half2[1] = __hmul2( - activation_kernel(__hadd2(val_half2[1], b_half2[1])), - scale_mask_2); - out_half2[2] = __hmul2( - activation_kernel(__hadd2(val_half2[2], b_half2[2])), - scale_mask_3); - out_half2[3] = __hmul2( - activation_kernel(__hadd2(val_half2[3], b_half2[3])), - scale_mask_4); - outs_float4[i] = out_float4; -} - -template <> -void launch_ls_dropout_act_bias( - float *out, const float *vals, uint8_t *mask, const float *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 10; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - __half *out, const __half *vals, uint8_t *mask, const __half *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 11; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - float *out, const float *vals, uint8_t *mask, const float *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 10; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - __half *out, const __half *vals, uint8_t *mask, const __half *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 11; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -/** - * @brief fused bias, activation, and dropout backward - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @tparam act_type kRelu - * @param row_size batch_size * seq_len - * @param ratio dropout ratio - * @param in_grad [batch_size, seq_len, hidden_size], input grad - * @param bias_grad [hidden_size], bias grad - * @param out_grad [batch_size, seq_len, hidden_size], output grad - * @param mask [batch_size, seq_len, hidden_size], dropout mask - * @param hidden_size - * @return void - */ -template -__global__ void ls_dropout_act_bias_bwd_kernel( - const int row_size, const float ratio, T *in_grad, - T *__restrict__ bias_grad, const T *__restrict__ input, - const T *__restrict__ bias, const T *out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - __shared__ float tile[WARP_SIZE][WARP_SIZE + 1]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - - int stride = hidden_size * WARP_SIZE; - float local_sum = 0; - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - if (col_idx < hidden_size) { - for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { - float val = out_grad[idx]; - float in = input[idx]; - float b = bias[idx % hidden_size]; - val = activation_bwd_kernel( - val * scale * static_cast(mask[idx]), in + b); - local_sum += val; - in_grad[idx] = val; - idx += stride; - } - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - float sum = tile[threadIdx.y][threadIdx.x]; - __syncthreads(); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; - __syncthreads(); - - if (threadIdx.y == 0) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - bias_grad[pos] = tile[0][threadIdx.x]; - } -} - -// @brief fused bias, activation, and dropout backward -// It is deprecated for precision reason. Keep it for future optimization. -// -// template -// __global__ void ls_dropout_act_bias_bwd_kernel( -// const int row_size, const float ratio, __half * in_grad, -// __half *__restrict__ bias_grad, const __half *__restrict__ input, const -// __half *__restrict__ bias, const __half * out_grad, const uint8_t -// *__restrict__ mask, const int hidden_size) { -// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); -// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1]; - -// cg::thread_block b = cg::this_thread_block(); -// cg::thread_block_tile g = cg::tiled_partition(b); - -// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); -// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); -// const __half2 *out_grad2 = reinterpret_cast(out_grad); -// const __half2 *input2 = reinterpret_cast(input); -// const __half2 *bias2 = reinterpret_cast(bias); - -// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - -// int stride = hidden_size * WARP_SIZE; -// __half2 local_sum = __float2half2_rn(0.f); - -// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); -// if (col_idx < hidden_size) { -// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { -// __half2 val = out_grad2[idx]; -// __half2 in2 = input2[idx]; -// __half2 b2 = bias2[idx % hidden_size ]; -// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); -// val = activation_bwd_kernel(val * scale -// * -// m2, -// in2+b2); -// local_sum += val; -// in_grad2[idx] = val; -// idx += stride; -// } -// } - -// tile[threadIdx.x][threadIdx.y] = local_sum; -// __syncthreads(); -// __half2 sum = tile[threadIdx.y][threadIdx.x]; -// __syncthreads(); - -// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - -// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; -// __syncthreads(); - -// if (threadIdx.y == 0) { -// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); -// bias_grad2[pos] = tile[0][threadIdx.x]; -// } -// } - -template -void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, - const T *bias, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream) { - dim3 grid_dim((dim - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - ls_dropout_act_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim); -} - -// template <> -// void launch_ls_dropout_act_bias_bwd( -// __half *in_grad, __half *bias_grad,const __half *input, const __half -// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int -// dim, float ratio, cudaStream_t stream) { -// dim >>= 1; -// dim3 grid_dim((dim - 1) / WARP_SIZE + 1); -// dim3 block_dim(WARP_SIZE, WARP_SIZE); -// ls_dropout_act_bias_bwd_kernel -// <<>>(row_size, ratio, in_grad, -// bias_grad, -// input, bias,out_grad, mask, dim); -// } - -template void launch_ls_dropout_act_bias_bwd( - float *in_grad, float *bias_grad, const float *input, const float *bias, - const float *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, - const __half *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - float *in_grad, float *bias_grad, const float *input, const float *bias, - const float *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, - const __half *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu deleted file mode 100644 index 625b02cd2..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu +++ /dev/null @@ -1,232 +0,0 @@ -#include - -#include "kernels.h" - -namespace cg = cooperative_groups; - -/** -@brief: fuse_transpose_bias -Calculate the sum of elements in each column of the matrix. - -@thread -gridDim.x = ceil(cols / WARP_SIZE) -blockDim.x = WARP_SIZE -blockDim.y = WARP_SIZE - -@param -inp: [rows, cols] -out: [cols] -rows: the number of rows in the matrix -cols: the number of cols in the matrix -*/ -template -__global__ void column_sum_reduce(const T *__restrict__ inp, - T *__restrict__ out, int rows, int cols) { - __shared__ float tile[WARP_SIZE][WARP_SIZE]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - int y_stride = cols * WARP_SIZE; - float localSum = 0; - - // Loop across matrix row - // TODO: optimize to log complexity - if (idx < cols) { - int offset = flat_2dim(threadIdx.y, idx, cols); - for (int r = threadIdx.y; r < rows; r += WARP_SIZE) { - localSum += (float)inp[offset]; - offset += y_stride; - } - } - - // The sum of a row in tile is equal to the sum of a col in original matrix - tile[threadIdx.x][threadIdx.y] = localSum; - - __syncthreads(); - - // Sum the shared buffer. - // The change of threadIdx.x is continuous - float sum = tile[threadIdx.y][threadIdx.x]; - - __syncthreads(); - - // Calculate the sum of a row in tile - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (threadIdx.x == 0) { - int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE); - if (pos < cols) out[pos] = sum; - } -} - -// [r, c] -> [c] -template <> -void launch_fuse_transpose_bias_kernel(const float *inp, float *out, - int rows, int cols, - cudaStream_t stream) { - dim3 grid_dim((cols - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - - column_sum_reduce - <<>>(inp, out, rows, cols); -} - -template <> -void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out, - int rows, int cols, - cudaStream_t stream) { - dim3 grid_dim((cols - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - - column_sum_reduce<__half> - <<>>(inp, out, rows, cols); -} - -/** -@brief: fused_add2 -Add two matrix inp1 and inp2 to out. - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = min(hidden_dim, MAX_THREADS) - -@param -inp1: [batch_size, seq_len, hidden_dim] -inp2: [batch_size, seq_len, hidden_dim] -out: [batch_size, seq_len, hidden_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -*/ -template -__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2, - int hidden_dim); - -template <> -__global__ void fused_add2_kernel(float *out, const float *inp1, - const float *inp2, int hidden_dim) { - int row_id = blockIdx.x; - int offset = flat_2dim(row_id, 0, hidden_dim); - - const float4 *inp1_4 = reinterpret_cast(inp1); - const float4 *inp2_4 = reinterpret_cast(inp2); - float4 *out_4 = reinterpret_cast(out); - float4 vinp1; - float4 vinp2; - float4 val; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinp1 = inp1_4[offset + i]; - vinp2 = inp2_4[offset + i]; - val.x = vinp1.x + vinp2.x; - val.y = vinp1.y + vinp2.y; - val.z = vinp1.z + vinp2.z; - val.w = vinp1.w + vinp2.w; - out_4[offset + i] = val; - } -} - -template <> -__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1, - const __half *inp2, int hidden_dim) { - int row_id = blockIdx.x; - int offset = flat_2dim(row_id, 0, hidden_dim); - - const float4 *inp1_4 = reinterpret_cast(inp1); - const float4 *inp2_4 = reinterpret_cast(inp2); - float4 *out_4 = reinterpret_cast(out); - float4 vinp1; - float4 vinp2; - float4 val; - __half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1); - __half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2); - __half2 *h2_val = reinterpret_cast<__half2 *>(&val); - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinp1 = inp1_4[offset + i]; - vinp2 = inp2_4[offset + i]; - h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]); - h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]); - h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]); - h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]); - out_4[offset + i] = val; - } -} - -//[b, s, h] -> [b, s, h] -template <> -void launch_fused_add2(float *out, const float *inp1, const float *inp2, - int batch_size, int seq_len, int hidden_dim, - cudaStream_t &stream) { - hidden_dim >>= 2; - - dim3 grid_dim(batch_size * seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - fused_add2_kernel<<>>(out, inp1, inp2, - hidden_dim); -} - -template <> -void launch_fused_add2<__half>(__half *out, const __half *inp1, - const __half *inp2, int batch_size, int seq_len, - int hidden_dim, cudaStream_t &stream) { - hidden_dim >>= 3; - - dim3 grid_dim(batch_size * seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - fused_add2_kernel<<>>(out, inp1, inp2, - hidden_dim); -} - -template -__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output, - int sz0, int sz2, int sz1_1, int sz1_2) { - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x); - if (idx >= nele) { - return; - } - float4 *dst_ptr = (float4 *)output + idx; - int idx2 = idx % sz2; - idx = idx / sz2; - int idx1 = idx % (sz1_1 + sz1_2); - int idx0 = idx / (sz1_1 + sz1_2); - float4 *src_ptr = nullptr; - int sz1 = 0; - if (idx1 < sz1_1) { - sz1 = sz1_1; - src_ptr = (float4 *)inp1; - } else { - idx1 -= sz1_1; - sz1 = sz1_2; - src_ptr = (float4 *)inp2; - } - src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2); - dst_ptr[0] = src_ptr[0]; -} - -template <> -void launch_concat3_dim1(const float *inp1, const float *inp2, - float *output, int sz0, int sz2, int sz1_1, - int sz1_2, cudaStream_t stream) { - sz2 >>= 2; - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; - kernel_concat3_dim1<<>>( - inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); -} - -template <> -void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2, - __half *output, int sz0, int sz2, int sz1_1, - int sz1_2, cudaStream_t stream) { - sz2 >>= 3; - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; - kernel_concat3_dim1<<>>( - inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); -} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/context.h b/colossalai/kernel/cuda_native/csrc/kernels/include/context.h deleted file mode 100644 index f7d75f38c..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/context.h +++ /dev/null @@ -1,36 +0,0 @@ -#pragma once - -#include -#include - -#include -#include - -#include "cuda_util.h" - -class Context { - public: - Context() : _stream(nullptr) { - CHECK_GPU_ERROR(cublasCreate(&_cublasHandle)); - } - - virtual ~Context() {} - - static Context &Instance() { - static Context _ctx; - return _ctx; - } - - void set_stream(cudaStream_t stream) { - _stream = stream; - CHECK_GPU_ERROR(cublasSetStream(_cublasHandle, _stream)); - } - - cudaStream_t get_stream() { return _stream; } - - cublasHandle_t get_cublashandle() { return _cublasHandle; } - - private: - cudaStream_t _stream; - cublasHandle_t _cublasHandle; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h deleted file mode 100644 index f4e9befc6..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h +++ /dev/null @@ -1,46 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -#include "cuda_util.h" - -template -class CrossEntropyLayer { - public: - CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens); - - virtual ~CrossEntropyLayer(); - - void Forward(const T *inputs_ptr, const int *targets_ptr, float *outputs_ptr, - float *nll_loss_ptr); - - void Backward(const float *grad_outputs_ptr, const T *inputs_ptr, - const int *targets_ptr, T *grad_inputs_ptr); - - void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size); - - private: - void allocate_mem_buffer() { - // allocate local gpu memory - _loss_buffer = cuda_malloc(_max_batch_tokens * 2); - } - - void free_mem_buffer() { - // free local gpu memory - cuda_free(_loss_buffer); - } - - const int _padding_idx; - const float _epsilon; - const int _max_batch_tokens; - - size_t _batch_size; - size_t _seq_len; - size_t _vocab_size; - - float *_loss_buffer; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h deleted file mode 100644 index 90255152b..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2021 The LightSeq Team - Copyright Microsoft DeepSpeed - This file is adapted from Microsoft DeepSpeed - Licensed under the MIT License. -*/ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float *alpha, const float *beta, const float *A, - const float *B, float *C, - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); - -int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float *alpha, const float *beta, const __half *A, - const __half *B, __half *C, - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); - -int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, - const float *alpha, const float *beta, - const float *A, const float *B, float *C, - cublasOperation_t op_A, cublasOperation_t op_B, - int stride_A, int stride_B, int stride_C, - int batch, - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); - -int cublas_strided_batched_gemm( - cublasHandle_t handle, int m, int n, int k, const float *alpha, - const float *beta, const __half *A, const __half *B, __half *C, - cublasOperation_t op_A, cublasOperation_t op_B, int stride_A, int stride_B, - int stride_C, int batch, - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h deleted file mode 100644 index 1595257be..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -template -void check_gpu_error(T result, char const *const func, const char *const file, - int const line); - -#define CHECK_GPU_ERROR(val) check_gpu_error((val), #val, __FILE__, __LINE__) - -template -void print_vec(const T *outv, std::string outn, int num_output_ele); - -template -T *cuda_malloc(size_t ele_num); - -void cuda_free(void *pdata); - -template -void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf, - std::string file, int line, cudaStream_t stream); - -#define CHECK_NAN_INF(ptr, size, stream) \ - check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \ - check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream)) diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h deleted file mode 100644 index 025fbf3f8..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h +++ /dev/null @@ -1,96 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -template -class Dropout { - public: - struct Config { - float ratio; - bool training; - - Config(float r) : ratio(r), training(true) {} - float RATIO() const { return training ? ratio : 0.0; } - }; - - Dropout(const Config &config, size_t max_ele_num) - : _config(config), _mask(nullptr) { - _mask = cuda_malloc(max_ele_num); - } - - virtual ~Dropout() { cuda_free(_mask); } - - // after attention softmax - void dropout(T *output, const T *input, int count, cudaStream_t stream, - bool bwd = false) { - launch_ls_dropout(output, input, _mask, count, _config.RATIO(), stream, - bwd); - } - - void d_dropout(T *d_inp_out, int count, cudaStream_t stream) { - launch_ls_dropout(d_inp_out, d_inp_out, _mask, count, _config.RATIO(), - stream, true); - } - - // transformer layer's postprocessing dropout, after attn or ffn module, - // before residual add. - void bias_dropout_residual(T *output, const T *input, const T *residual, - const T *bias, int rows, int cols, - cudaStream_t stream) { - launch_ls_dropout_res_bias(output, input, _mask, bias, residual, - rows * cols, cols, _config.RATIO(), stream); - } - - void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output, - int rows, int cols, cudaStream_t stream) { - launch_ls_dropout_bias_bwd(d_input, d_bias, d_output, _mask, rows, cols, - _config.RATIO(), stream); - } - - // dropout inside ffn. - void bias_act_dropout(T *output, const T *input, const T *bias, int rows, - int cols, std::string activation_fn, - cudaStream_t stream) { - if (activation_fn == "relu") { - launch_ls_dropout_act_bias( - output, input, _mask, bias, rows * cols, cols, _config.RATIO(), - stream); - } else if (activation_fn == "gelu") { - launch_ls_dropout_act_bias( - output, input, _mask, bias, rows * cols, cols, _config.RATIO(), - stream); - } else { - throw std::runtime_error("not supported activation: " + activation_fn); - } - } - - void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input, - const T *bias, int rows, int cols, - std::string activation_fn, cudaStream_t stream) { - if (activation_fn == "relu") { - launch_ls_dropout_act_bias_bwd( - d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, - _config.RATIO(), stream); - } else if (activation_fn == "gelu") { - launch_ls_dropout_act_bias_bwd( - d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, - _config.RATIO(), stream); - } else { - throw std::runtime_error("not supported activation: " + activation_fn); - } - } - - bool HasDropout() const { return _config.RATIO() > 0.0; } - - void SetTrainingMode(bool training) { _config.training = training; } - - private: - uint8_t *_mask; - Config _config; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h b/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h deleted file mode 100644 index 8186da1ee..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h +++ /dev/null @@ -1,69 +0,0 @@ -#pragma once - -/* Copyright 2021 The LightSeq Team - Copyright Microsoft DeepSpeed - This file is adapted from Microsoft DeepSpeed - Licensed under the MIT License. -*/ -#include -#include -#include - -#include - -#include "cublas_wrappers.h" -#include "kernels.h" - -template -class FeedForward { - public: - struct Config { - int outputSize; - int inputSize; - std::array gemm_algos; - Config(int outputs, int inputs) - : outputSize(outputs), - inputSize(inputs), - gemm_algos(std::array({99, 99, 99})) {} - }; - - FeedForward(Config config) : config_(config) {} - - ~FeedForward() {} - - void Forward(int bsz, const T *input_ptr, const T *weights, T *out, - cublasHandle_t &_cublasHandle) { - float alpha = T(1.); - float beta = T(0.); - - cublas_gemm_ex(_cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N, config_.outputSize, - bsz, config_.inputSize, &alpha, &beta, weights, input_ptr, - out, cublasGemmAlgo_t(config_.gemm_algos[0])); - } - void Backward(int bsz, const T *out_grad, const T *input_ptr, - const T *weights, T *weights_grad, T *bias_grad, - cublasHandle_t &_cublasHandle, cudaStream_t &stream, - T *inp_grad_out = nullptr, T *out_grad_trans_out = nullptr, - bool compute_bias = true) { - float alpha = (T)1.0, beta = (T)0.0; - cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_T, config_.inputSize, - config_.outputSize, bsz, &alpha, &beta, input_ptr, out_grad, - weights_grad, cublasGemmAlgo_t(config_.gemm_algos[1])); - - cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, config_.inputSize, - bsz, config_.outputSize, &alpha, &beta, weights, out_grad, - inp_grad_out, cublasGemmAlgo_t(config_.gemm_algos[2])); - if (compute_bias) { - launch_fuse_transpose_bias_kernel(out_grad, bias_grad, bsz, - config_.outputSize, stream); - } - } - - void reset_size(int outputSize, int inputSize) { - config_.outputSize = outputSize; - config_.inputSize = inputSize; - } - - private: - Config config_; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h deleted file mode 100644 index 735e1363c..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h +++ /dev/null @@ -1,275 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#include - -#define MAX_THREADS 1024 -#define WARP_SIZE 32 - -enum class ActivationType { kRelu, kGelu }; - -void launch_curand_init(int total_count, int dim, cudaStream_t stream); - -template -void launch_layer_norm(T *ln_res, T *vars, T *means, const T *inp, - const T *scale, const T *bias, int batch_size, - int hidden_dim, cudaStream_t stream); - -template -void launch_ln_bw(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, const T *gamma, - const T *betta, const T *vars, const T *means, int batch, - int hidden_dim, cudaStream_t stream[2]); - -template -void launch_attn_softmax(T *vals, const T *attn_mask, int batch_size, int heads, - int from_len, int to_len, bool mask_future, - cudaStream_t stream); - -template -void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, - int softmax_len, cudaStream_t stream); - -// [b, s, h] -> [b, nh, s, ad] -template -void launch_transform_0213(T *output, const T *vals, int batch_size, - int seq_length, int hidden_dim, int nhead, - cudaStream_t stream); - -// [b, s, 3, h] -> [3, b, nh, s, ad] -template -void launch_bias_add_transform_20314(T *output, const T *input, const T *bias, - int dim_0, int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream); - -// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] -template -void launch_transform4d_0213(T *output, const T *vals, int batch_size, - int seq_len, int hidden_dim, int nhead, - int trans_count, cudaStream_t stream); - -template -void launch_ls_dropout(T *out, const T *vals, uint8_t *mask, int total_count, - float ratio, cudaStream_t stream, bool backward = false); - -template -void launch_ls_dropout_res_bias(T *out, const T *vals, uint8_t *mask, - const T *bias, const T *residual, - int total_count, int dim, float ratio, - cudaStream_t stream); - -template -void launch_ls_dropout_act_bias(T *out, const T *vals, uint8_t *mask, - const T *bias, int total_count, int dim, - float ratio, cudaStream_t stream); - -template -void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template -void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, - const T *bias, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template -void launch_fuse_transpose_bias_kernel(const T *inp, T *out, int rows, int cols, - cudaStream_t stream); - -void launch_param_update(const float *input, __half *output, int size, - cudaStream_t stream); - -template -void launch_concat3_dim1(const T *inp1, const T *inp2, T *output, int sz0, - int sz2, int sz1_1, int sz1_2, cudaStream_t stream); - -template -void launch_fused_add2(T *out, const T *inp1, const T *inp2, int batch_size, - int seq_len, int hidden_size, cudaStream_t &stream); - -template -void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr, - float *outputs_ptr, float *nll_loss_ptr, - float *loss_buffer, const int padding_idx, - const float epsilon, const int batch_size, - const int seq_len, const int vocab_size, - cudaStream_t stream); - -template -void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr, - const int *targets_ptr, T *grad_inputs_ptr, - const int padding_idx, const float epsilon, - const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); - -template -void launch_lookup_scale_pos_dropout( - T *output, const int *input, const T *embeddings, const T *pos_embeddings, - uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim, - int padding_idx, float dropout_ratio, int step, cudaStream_t &stream); - -template -void launch_d_lookup_scale_pos_dropout( - T *grad_embeddings, const T *grad_output, const int *input, - const uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim, - int vocab_size, int padding_idx, float dropout_ratio, cudaStream_t &stream); - -/* Convert 2-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_2dim(int id1, int id2, int dim2) { - return id1 * dim2 + id2; -} - -/* Convert 3-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3, - int dim2, int dim3) { - return id1 * dim2 * dim3 + id2 * dim3 + id3; -} - -/* Convert 4-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_4dim(int id1, int id2, int id3, - int id4, int dim2, int dim3, - int dim4) { - // return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4; - int res = id4; - - int ld = dim4; - res += id3 * ld; - - ld *= dim3; - res += id2 * ld; - - ld *= dim2; - res += id1 * ld; - - return res; -} - -/* Convert 5-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_5dim(int id1, int id2, int id3, - int id4, int id5, int dim2, - int dim3, int dim4, - int dim5) { - // return id1*(dim2*dim3*dim4*dim5) + id2*(dim3*dim4*dim5) + id3*(dim4*dim5) + - // id4*dim5 + dim5; - int res = id5; - - int ld = dim5; - res += id4 * ld; - - ld *= dim4; - res += id3 * ld; - - ld *= dim3; - res += id2 * ld; - - ld *= dim2; - res += id1 * ld; - - return res; -} - -/* Convert 6-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3, - int id4, int id5, int id6, - int dim2, int dim3, int dim4, - int dim5, int dim6) { - // return id1*(dim2*dim3*dim4*dim5*dim6) + id2*(dim3*dim4*dim5*dim6) + - // id3*(dim4*dim5*dim6) + id4*(dim5*dim6) + id5*dim6 + id6; - int res = id6; - - int ld = dim6; - res += id5 * ld; - - ld *= dim5; - res += id4 * ld; - - ld *= dim4; - res += id3 * ld; - - ld *= dim3; - res += id2 * ld; - - ld *= dim2; - res += id1 * ld; - - return res; -} - -/* Convert vector index to 6-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_6dim( - int src, int dim1, int dim2, int dim3, int dim4, int dim5, int *id0, - int *id1, int *id2, int *id3, int *id4, int *id5) { - *id5 = src % dim5; - src /= dim5; - - *id4 = src % dim4; - src /= dim4; - - *id3 = src % dim3; - src /= dim3; - - *id2 = src % dim2; - src /= dim2; - - *id1 = src % dim1; - *id0 = src / dim1; -} - -/* Convert vector index to 5-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_5dim(int src, int dim1, - int dim2, int dim3, - int dim4, int *id0, - int *id1, int *id2, - int *id3, int *id4) { - *id4 = src % dim4; - src /= dim4; - - *id3 = src % dim3; - src /= dim3; - - *id2 = src % dim2; - src /= dim2; - - *id1 = src % dim1; - *id0 = src / dim1; -} - -/* Convert vector index to 4-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1, - int dim2, int dim3, - int *id0, int *id1, - int *id2, int *id3) { - *id3 = src % dim3; - src /= dim3; - - *id2 = src % dim2; - src /= dim2; - - *id1 = src % dim1; - *id0 = src / dim1; -} - -/* Convert vector index to 3-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1, - int dim2, int *id0, - int *id1, int *id2) { - *id2 = src % dim2; - src /= dim2; - - *id1 = src % dim1; - *id0 = src / dim1; -} - -/* Convert vector index to 2-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_2dim(int src, int dim1, - int *id0, int *id1) { - *id1 = src % dim1; - *id0 = src / dim1; -} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh b/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh deleted file mode 100644 index 4f65e7b54..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh +++ /dev/null @@ -1,12 +0,0 @@ -// copied from https://github.com/dmlc/dgl/pull/2758 -#ifndef DGL_ARRAY_CUDA_DGL_CUB_CUH_ -#define DGL_ARRAY_CUDA_DGL_CUB_CUH_ - -#define CUB_NS_PREFIX namespace ls { -#define CUB_NS_POSTFIX } -#include "cub/cub.cuh" -#include "cub/util_allocator.cuh" -#undef CUB_NS_POSTFIX -#undef CUB_NS_PREFIX - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h deleted file mode 100644 index a7767e187..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h +++ /dev/null @@ -1,65 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -using namespace std; - -template -class Normalize_Layer { - public: - struct Config { - uint32_t hidden_dim; - bool use_mean; - Config(uint32_t hidden_dim, bool use_mean = false) - : hidden_dim(hidden_dim), use_mean(use_mean) {} - }; - - Normalize_Layer(Config config, size_t max_rows) - : config_(config), vars_(nullptr), means_(nullptr) { - vars_ = cuda_malloc(max_rows); - if (config_.use_mean) { - means_ = cuda_malloc(max_rows); - } - } - - ~Normalize_Layer() { - cuda_free(vars_); - cuda_free(means_); - } - - void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta, - int batch_size, cudaStream_t stream) { - launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size, - config_.hidden_dim, stream); - } - - /* - residual_grad, inp_or_out, betta should be treated carefully. - inp_or_out = input if use_mean else output - residual_grad, betta can be nullptr. - residual_grad will be added to dinp if it is not nullptr - which is useful in transformer layer when pre-ln - betta are only used to compute xhat, - (use_mean == false) ^ (betta == nullptr) should be true - */ - void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, const T *gamma, - const T *betta, int batch_size, cudaStream_t stream[2]) { - launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad, - inp_or_out, gamma, betta, vars_, means_, batch_size, - config_.hidden_dim, stream); - } - - inline bool use_mean() const { return config_.use_mean; } - - private: - Config config_; - T *vars_; - T *means_; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h deleted file mode 100644 index b917abaf0..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h +++ /dev/null @@ -1,42 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -using namespace std; - -template -class Softmax { - public: - struct Config { - size_t nhead; - Config(size_t nhead) : nhead(nhead) {} - }; - - Softmax(Config config) : config_(config) {} - - ~Softmax() {} - - void Forward(T *vals, const T *attn_mask, int batch_size, int from_len, - int to_len, cudaStream_t &stream, bool mask_future = true) { - launch_attn_softmax(vals, attn_mask, batch_size, config_.nhead, from_len, - to_len, mask_future, stream); - } - - void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len, - int to_len, cudaStream_t stream) { - launch_attn_softmax_bw(out_grad, soft_out, - batch_size * config_.nhead * from_len, to_len, - stream); - } - - void reset_size(size_t nhead) { config_.nhead = nhead; } - - private: - Config config_; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h b/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h deleted file mode 100644 index d386650e8..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright 2021 The LightSeq Team - Copyright Microsoft DeepSpeed - This file is adapted from Microsoft DeepSpeed - Licensed under the MIT License. -*/ -#pragma once - -#include -#include -#include - -#include - -#include "cublas_wrappers.h" - -template -class StridedBatchGemm { - public: - struct Config { - int m; - int n; - int k; - float alpha; - float beta; - cublasOperation_t op_A; - cublasOperation_t op_B; - std::array gemm_algos; - - Config(float param_alpha, float param_beta, cublasOperation_t opA, - cublasOperation_t opB) - : alpha(param_alpha), - beta(param_beta), - op_A(opA), - op_B(opB), - gemm_algos(std::array({99, 99, 99})) {} - void SetConfig(int mm, int nn, int kk) { - m = mm; - n = nn; - k = kk; - } - }; - - StridedBatchGemm(const Config &config) : _config(config) {} - - virtual ~StridedBatchGemm() {} - - void Forward(int bsz, T *output, const T *_buffer_a, const T *_buffer_b, - cublasHandle_t handle) { - int stride_a = _config.m * _config.k; - int stride_b = _config.n * _config.k; - int stride_c = _config.m * _config.n; - - cublas_strided_batched_gemm( - handle, _config.m, _config.n, _config.k, &_config.alpha, &_config.beta, - _buffer_a, _buffer_b, output, _config.op_A, _config.op_B, stride_a, - stride_b, stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[0])); - } - - void Backward(int bsz, const T *d_output, const T *_buffer_a, - const T *_buffer_b, cublasHandle_t handle, - T *inpGradA = nullptr, T *inpGradB = nullptr) { - int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m); - int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k); - - int stride_a = mb * _config.n; - int stride_b = _config.n * kb; - int stride_c = _config.m * _config.k; - - // B need to transpose. - cublasOperation_t op_b = - (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); - - // Calculate d_A. - cublas_strided_batched_gemm( - handle, mb, kb, _config.n, &_config.alpha, &_config.beta, - (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output), - (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), inpGradA, - CUBLAS_OP_N, op_b, stride_a, stride_b, stride_c, bsz, - cublasGemmAlgo_t(_config.gemm_algos[1])); - - // A need to transpose. - cublasOperation_t op_a = - (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); - - stride_a = _config.m * _config.k; - stride_b = _config.m * _config.n; - stride_c = _config.n * _config.k; - - // Calculate d_B. - cublas_strided_batched_gemm( - handle, _config.k, _config.n, _config.m, &_config.alpha, &_config.beta, - _buffer_a, d_output, inpGradB, op_a, CUBLAS_OP_N, stride_a, stride_b, - stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[2])); - } - - inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); } - - private: - Config _config; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu deleted file mode 100644 index e2f1869b1..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu +++ /dev/null @@ -1,1172 +0,0 @@ -#include - -#include "block_reduce.h" -#include "kernels.h" - -namespace cg = cooperative_groups; -const float LN_EPSILON = 1e-8f; -#define TILE_DIM 32 - -template -__forceinline__ __device__ T add_eps(T x) { - return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON); -} - -/** -@brief: ker_layer_norm -Standard layer normalization. -It will not only output the layer norm result, - but also outputs variance. - may also output means, depends on whether - the means argument is nullptr - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = hidden_size - -@param -ln_res: [batch_size* seq_len, hidden_size], ln result. -vars: [batch_size* seq_len], variance per token -means: [batch_size* seq_len], means per token, can be nullput -inp: [batch_size * seq_len, hidden_size], ln input. -scale: [hidden_size], ln scale -bias: [hidden_size], ln bias -*/ -template -__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp, - const T *scale, const T *bias, int hidden_size) { - // step 0. compute local sum - float l_sum = 0; - float l_square_sum = 0; - const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 val = inp_f4[idx]; - l_sum += val.x + val.y + val.z + val.w; - l_square_sum += - val.x * val.x + val.y * val.y + val.z * val.z + val.w * val.w; - } - - // step 1. compute reduce sum - float mean_dim = float(hidden_size) * 4.f; - float reduce_val[2] = {l_sum, l_square_sum}; - blockReduce(reduce_val); - __shared__ float s_mean, s_var; - if (threadIdx.x == 0) { - s_mean = reduce_val[0] / mean_dim; - if (means != nullptr) { - means[blockIdx.x] = s_mean; - } - s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; - vars[blockIdx.x] = s_var; - s_var = rsqrtf(s_var); - } - __syncthreads(); - - // step 2. layer norm result - float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 vscale = __ldg((const float4 *)scale + idx); - float4 vbias = __ldg((const float4 *)bias + idx); - float4 val = inp_f4[idx]; - val.x = (val.x - s_mean) * s_var * vscale.x + vbias.x; - val.y = (val.y - s_mean) * s_var * vscale.y + vbias.y; - val.z = (val.z - s_mean) * s_var * vscale.z + vbias.z; - val.w = (val.w - s_mean) * s_var * vscale.w + vbias.w; - output_f4[idx] = val; - } -} - -template <> -__global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, - __half *means, const __half *inp, - const __half *scale, const __half *bias, - int hidden_size) { - // step 0. compute local sum - float l_sum = 0; - float l_square_sum = 0; - const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 val_f4 = inp_f4[idx]; - __half2 *val_h2 = (__half2 *)(&val_f4); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 val_f2 = __half22float2(val_h2[i]); - l_sum += val_f2.x + val_f2.y; - l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y; - } - } - - // step 1. compute reduce sum - float mean_dim = float(hidden_size) * 8.f; - float reduce_val[2] = {l_sum, l_square_sum}; - blockReduce(reduce_val); - __shared__ float s_mean, s_var; - if (threadIdx.x == 0) { - s_mean = reduce_val[0] / mean_dim; - if (means != nullptr) { - means[blockIdx.x] = s_mean; - } - s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; - vars[blockIdx.x] = s_var; - s_var = rsqrtf(s_var); - } - __syncthreads(); - - // step 2. layer norm result - float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - // load scale, bias, input - float4 scale_f4 = __ldg((const float4 *)scale + idx); - __half2 *scale_h2 = (__half2 *)(&scale_f4); - float4 bias_f4 = __ldg((const float4 *)bias + idx); - __half2 *bias_h2 = (__half2 *)(&bias_f4); - float4 val_f4 = inp_f4[idx]; - __half2 *val_h2 = (__half2 *)(&val_f4); - -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 scale_f2 = __half22float2(scale_h2[i]); - float2 bias_f2 = __half22float2(bias_h2[i]); - float2 val_f2 = __half22float2(val_h2[i]); - val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; - val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; - val_h2[i] = __float22half2_rn(val_f2); - } - output_f4[idx] = val_f4; - } -} - -// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars, -// __half *means, const __half *inp, -// const __half *scale, const __half -// *bias, int hidden_size) { -// // step 0. compute local sum -// float l_sum = 0; -// float l_square_sum = 0; -// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size; -// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * -// 2) { -// float4 val_f4 = inp_f4[idx]; -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y; -// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x -// * val_f2_1.x + val_f2_1.y * val_f2_1.y; -// } -// } - -// // step 1. compute reduce sum -// float mean_dim = float(hidden_size) * 8.f * 2; -// float reduce_val[2] = {l_sum, l_square_sum}; -// blockReduce(reduce_val); -// __shared__ float s_mean, s_var; -// if (threadIdx.x == 0) { -// s_mean = reduce_val[0] / mean_dim; -// if (means != nullptr) { -// means[blockIdx.x] = s_mean; -// } -// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; -// vars[blockIdx.x] = s_var; -// s_var = rsqrtf(s_var); -// } -// __syncthreads(); - -// // step 2. layer norm result -// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2; -// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * -// 2) { -// // load scale, bias, input -// float4 scale_f4 = __ldg((const float4 *)scale + idx); -// __half2 *scale_h2 = (__half2 *)(&scale_f4); -// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); -// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); -// float4 bias_f4 = __ldg((const float4 *)bias + idx); -// __half2 *bias_h2 = (__half2 *)(&bias_f4); -// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); -// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); -// float4 val_f4 = inp_f4[idx]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); - -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 scale_f2 = __half22float2(scale_h2[i]); -// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); -// float2 bias_f2 = __half22float2(bias_h2[i]); -// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; -// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; -// val_h2[i] = __float22half2_rn(val_f2); -// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + -// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y -// + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1); -// } -// output_f4[idx] = val_f4; -// output_f4[idx+1] = val_f4_1; -// } -// } - -// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars, -// __half *means, const __half *inp, -// const __half *scale, const __half -// *bias, int hidden_size) { -// // step 0. compute local sum -// float l_sum = 0; -// float l_square_sum = 0; -// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4; -// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * -// 4) { -// float4 val_f4 = inp_f4[idx]; -// float4 val_f4_1 = inp_f4[idx+1]; -// float4 val_f4_2 = inp_f4[idx+2]; -// float4 val_f4_3 = inp_f4[idx+3]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); -// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// float2 val_f2_2 = __half22float2(val_h2_2[i]); -// float2 val_f2_3 = __half22float2(val_h2_3[i]); -// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + -// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x * -// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x -// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x + -// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x + -// val_f2_3.y * val_f2_3.y; -// } -// } - -// // step 1. compute reduce sum -// float mean_dim = float(hidden_size) * 8.f * 4; -// float reduce_val[2] = {l_sum, l_square_sum}; -// blockReduce(reduce_val); -// __shared__ float s_mean, s_var; -// if (threadIdx.x == 0) { -// s_mean = reduce_val[0] / mean_dim; -// if (means != nullptr) { -// means[blockIdx.x] = s_mean; -// } -// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; -// vars[blockIdx.x] = s_var; -// s_var = rsqrtf(s_var); -// } -// __syncthreads(); - -// // step 2. layer norm result -// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4; -// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * -// 4) { -// // load scale, bias, input -// float4 scale_f4 = __ldg((const float4 *)scale + idx); -// __half2 *scale_h2 = (__half2 *)(&scale_f4); -// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); -// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); -// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2); -// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2); -// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3); -// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3); -// float4 bias_f4 = __ldg((const float4 *)bias + idx); -// __half2 *bias_h2 = (__half2 *)(&bias_f4); -// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); -// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); -// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2); -// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2); -// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3); -// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3); -// float4 val_f4 = inp_f4[idx]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// float4 val_f4_2 = inp_f4[idx+2]; -// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); -// float4 val_f4_3 = inp_f4[idx+3]; -// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); - -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 scale_f2 = __half22float2(scale_h2[i]); -// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); -// float2 scale_f2_2 = __half22float2(scale_h2_2[i]); -// float2 scale_f2_3 = __half22float2(scale_h2_3[i]); -// float2 bias_f2 = __half22float2(bias_h2[i]); -// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); -// float2 bias_f2_2 = __half22float2(bias_h2_2[i]); -// float2 bias_f2_3 = __half22float2(bias_h2_3[i]); -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// float2 val_f2_2 = __half22float2(val_h2_2[i]); -// float2 val_f2_3 = __half22float2(val_h2_3[i]); -// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; -// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; -// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + -// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y -// + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var * -// scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var -// * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) * -// s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean) -// * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] = -// __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1); -// val_h2_2[i] = __float22half2_rn(val_f2_2); -// val_h2_3[i] = __float22half2_rn(val_f2_3); -// } -// output_f4[idx] = val_f4; -// output_f4[idx+1] = val_f4_1; -// output_f4[idx+2] = val_f4_2; -// output_f4[idx+3] = val_f4_3; -// } -// } - -template <> -void launch_layer_norm(float *ln_res, float *vars, float *means, - const float *inp, const float *scale, - const float *bias, int batch_size, int hidden_dim, - cudaStream_t stream) { - if (hidden_dim % 4 != 0) { - throw std::runtime_error("violate hidden_dim % 4 = 0"); - } - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - dim3 grid_dim(batch_size); - dim3 block_dim(nthread); - - ker_layer_norm<<>>( - ln_res, vars, means, inp, scale, bias, hidden_dim); -} - -template <> -void launch_layer_norm<__half>(__half *ln_res, __half *vars, __half *means, - const __half *inp, const __half *scale, - const __half *bias, int batch_size, - int hidden_dim, cudaStream_t stream) { - if (hidden_dim % 8 != 0) { - throw std::runtime_error("violate hidden_dim % 8 = 0"); - } - hidden_dim >>= 3; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - dim3 grid_dim(batch_size); - dim3 block_dim(nthread); - - ker_layer_norm<__half><<>>( - ln_res, vars, means, inp, scale, bias, hidden_dim); - // if (hidden_dim % 8 != 0) { - // throw std::runtime_error("violate hidden_dim % 8 = 0"); - // } - // hidden_dim >>= 3; - - // if (hidden_dim * 8 < 8192) { - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm<__half><<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else if (hidden_dim * 8 >= 8192 && hidden_dim * 8 <= 8192 * 2) { - // hidden_dim >>= 1; - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm_x2<<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else if (hidden_dim * 8 > 8192 * 2 && hidden_dim * 8 <= 8192 * 4) { - // hidden_dim >>= 2; - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm_x4<<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else { - // throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); - // } -} - -/** -@brief: ker_ln_bw_dgamma_dbetta -Layer norm backword kernel, compute the gradient of gamma and betta. -dbetta = sum(dout, dim=0) -dgamma = sum(xhat * dout, dim=0) -xhat = (input - mean) * rsqrt(var) or - (output - betta) / gamma - - -@thread -gridDim.x = hidden_size / 32 -blockDim.x = 32 -blockDim.y = 32 - -@param -gamma_grad: [hidden_size], gradient of gamma -betta_grad: [hidden_size], gradient of betta -out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr - ln input if means is not nullptr -gamma: [hidden_size], gamma of ln, - used to compute xhat, maybe nullptr -betta: [hidden_size], betta of ln, - used to compute xhat, maybe nullptr -vars: [batch_size * seq_len], variance of ln forward, - used to compute xhat, maybe nullptr -means: [batch_size * seq_len], mean of ln forward, - used to compute xhat, maybe nullptr -(gamma && betta) ^ (vars && means) should be true -*/ -template -__global__ void ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, - const T *out_grad, const T *inp_or_out, - const T *gamma, const T *betta, - const T *vars, const T *means, int rows, - int width) { - __shared__ float betta_buffer[TILE_DIM][TILE_DIM]; - __shared__ float gamma_buffer[TILE_DIM][TILE_DIM]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int offset = threadIdx.y * width + idx; - int y_stride = width * TILE_DIM; - - // Loop across inp height - float dbetta = 0; - float dgamma = 0; - float dout, val; - if (idx < width) { - if (means == nullptr) { - float vbetta = (float)betta[idx]; - float vgamma = (float)gamma[idx]; - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - dout = (float)out_grad[offset]; - // inp_or_out is output - val = (float)inp_or_out[offset]; - dbetta += dout; - dgamma += ((val - vbetta) / add_eps(vgamma) * dout); - offset += y_stride; - } - } else { - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - dout = (float)out_grad[offset]; - // inp_or_out is input - val = (float)inp_or_out[offset]; - dbetta += dout; - dgamma += ((val - (float)means[r]) * - rsqrtf((float)vars[r] + LN_EPSILON) * dout); - offset += y_stride; - } - } - } - - // Sum the shared buffer. - betta_buffer[threadIdx.x][threadIdx.y] = dbetta; - gamma_buffer[threadIdx.x][threadIdx.y] = dgamma; - __syncthreads(); - float s1 = betta_buffer[threadIdx.y][threadIdx.x]; - float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; - __syncthreads(); - - for (int i = 1; i < TILE_DIM; i <<= 1) { - s1 += g.shfl_down(s1, i); - s2 += g.shfl_down(s2, i); - } - - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - if (threadIdx.x == 0 && idx < width) { - betta_grad[pos] = s1; - gamma_grad[pos] = s2; - } -} - -/** -@brief: ker_ln_bw_dinp -Layer norm backword kernel, compute the gradient of input. -dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim) - * rsqrt(var) -xhat = (input - mean) * rsqrt(var) if mean is not nullptr - (output - betta) / gamma if mean is nullptr -dxhat = dout * gamma - - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = hidden_size - -@param -inp_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -residual_grad: [batch_size * seq_len, hidden_size], gradient of residual input, - usually appear in pre-layer-norm for transformer layer, maybe nullptr -inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr - ln input if means is not nullptr -gamma: [hidden_size], gamma of ln, - used to compute xhat and dxhat -betta: [hidden_size], betta of ln, - used to compute xhat, maybe nullptr -vars: [batch_size * seq_len], variance of ln forward, - used to compute xhat and dinp -means: [batch_size * seq_len], mean of ln forward, - used to compute xhat, maybe nullptr -*/ -template -__global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, - const T *gamma, const T *betta, const T *vars, - const T *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim + threadIdx.x; - float4 dxhat, xhat; - float var_rsqrt; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - dxhat = ((const float4 *)out_grad)[offset]; - float4 vgamma = ((const float4 *)gamma)[threadIdx.x]; - dxhat.x *= vgamma.x; - dxhat.y *= vgamma.y; - dxhat.z *= vgamma.z; - dxhat.w *= vgamma.w; - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - xhat = ((const float4 *)inp_or_out)[offset]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[threadIdx.x]; - xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x); - xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y); - xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z); - xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w); - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; - xhat.x = (xhat.x - fmean) * var_rsqrt; - xhat.y = (xhat.y - fmean) * var_rsqrt; - xhat.z = (xhat.z - fmean) * var_rsqrt; - xhat.w = (xhat.w - fmean) * var_rsqrt; - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - float reduce_val[2] = {0.f, 0.f}; - if (threadIdx.x < hidden_dim) { - reduce_val[0] = dxhat.x + dxhat.y + dxhat.z + dxhat.w; - reduce_val[1] = dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z + - dxhat.w * xhat.w; - } - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 4; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - dxhat.x = (dxhat.x - s_sum_dxhat - xhat.x * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.y = (dxhat.y - s_sum_dxhat - xhat.y * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.z = (dxhat.z - s_sum_dxhat - xhat.z * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.w = (dxhat.w - s_sum_dxhat - xhat.w * s_sum_dxhat_xhat) * var_rsqrt; - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - dxhat.x += dresidual.x; - dxhat.y += dresidual.y; - dxhat.z += dresidual.z; - dxhat.w += dresidual.w; - } - ((float4 *)inp_grad)[offset] = dxhat; -} - -template <> -__global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, - const __half *gamma, const __half *betta, - const __half *vars, const __half *means, - int hidden_dim) { - int offset = blockIdx.x * hidden_dim + threadIdx.x; - - float2 dxhat[4], xhat[4]; - float var_rsqrt; - float4 vtmp; - __half2 *tmp_h2; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[threadIdx.x]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vbetta = __half22float2(betta_h2[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; -} - -__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, const __half *gamma, - const __half *betta, const __half *vars, - const __half *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2; - - float2 dxhat[4], xhat[4]; - float2 dxhat_1[4], xhat_1[4]; - float var_rsqrt; - float4 vtmp, vtmp_1; - __half2 *tmp_h2; - __half2 *tmp_h2_1; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - vtmp_1 = ((const float4 *)out_grad)[offset + 1]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 2]; - float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 2 + 1]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); - __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vdout_1 = __half22float2(tmp_h2_1[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - dxhat_1[i].x = vdout_1.x * vgamma_1.x; - dxhat_1[i].y = vdout_1.y * vgamma_1.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[2 * threadIdx.x]; - float4 vbetta_1 = ((const float4 *)betta)[2 * threadIdx.x + 1]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); - __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vout_1 = __half22float2(tmp_h2_1[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vbetta = __half22float2(betta_h2[i]); - float2 vbetta_1 = __half22float2(betta_h2_1[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - float2 vinp_1 = __half22float2(tmp_h2_1[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8 * 2; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); - __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; - ((float4 *)inp_grad)[offset + 1] = vtmp_1; -} - -__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, const __half *gamma, - const __half *betta, const __half *vars, - const __half *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4; - - float2 dxhat[4], xhat[4]; - float2 dxhat_1[4], xhat_1[4]; - float2 dxhat_2[4], xhat_2[4]; - float2 dxhat_3[4], xhat_3[4]; - float var_rsqrt; - float4 vtmp, vtmp_1, vtmp_2, vtmp_3; - __half2 *tmp_h2; - __half2 *tmp_h2_1; - __half2 *tmp_h2_2; - __half2 *tmp_h2_3; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - vtmp_1 = ((const float4 *)out_grad)[offset + 1]; - vtmp_2 = ((const float4 *)out_grad)[offset + 2]; - vtmp_3 = ((const float4 *)out_grad)[offset + 3]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); - tmp_h2_2 = reinterpret_cast<__half2 *>(&vtmp_2); - tmp_h2_3 = reinterpret_cast<__half2 *>(&vtmp_3); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 4]; - float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 4 + 1]; - float4 gamma_f4_2 = ((const float4 *)gamma)[threadIdx.x * 4 + 2]; - float4 gamma_f4_3 = ((const float4 *)gamma)[threadIdx.x * 4 + 3]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); - __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); - __half2 *gamma_h2_2 = reinterpret_cast<__half2 *>(&gamma_f4_2); - __half2 *gamma_h2_3 = reinterpret_cast<__half2 *>(&gamma_f4_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vdout_1 = __half22float2(tmp_h2_1[i]); - float2 vdout_2 = __half22float2(tmp_h2_2[i]); - float2 vdout_3 = __half22float2(tmp_h2_3[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vgamma_2 = __half22float2(gamma_h2_2[i]); - float2 vgamma_3 = __half22float2(gamma_h2_3[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - dxhat_1[i].x = vdout_1.x * vgamma_1.x; - dxhat_1[i].y = vdout_1.y * vgamma_1.y; - dxhat_2[i].x = vdout_2.x * vgamma_2.x; - dxhat_2[i].y = vdout_2.y * vgamma_2.y; - dxhat_3[i].x = vdout_3.x * vgamma_3.x; - dxhat_3[i].y = vdout_3.y * vgamma_3.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + - dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x + - dxhat_3[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; - vtmp_2 = ((const float4 *)inp_or_out)[offset + 2]; - vtmp_3 = ((const float4 *)inp_or_out)[offset + 3]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[4 * threadIdx.x]; - float4 vbetta_1 = ((const float4 *)betta)[4 * threadIdx.x + 1]; - float4 vbetta_2 = ((const float4 *)betta)[4 * threadIdx.x + 2]; - float4 vbetta_3 = ((const float4 *)betta)[4 * threadIdx.x + 3]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); - __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); - __half2 *betta_h2_2 = reinterpret_cast<__half2 *>(&vbetta_2); - __half2 *betta_h2_3 = reinterpret_cast<__half2 *>(&vbetta_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vout_1 = __half22float2(tmp_h2_1[i]); - float2 vout_2 = __half22float2(tmp_h2_2[i]); - float2 vout_3 = __half22float2(tmp_h2_3[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vgamma_2 = __half22float2(gamma_h2_2[i]); - float2 vgamma_3 = __half22float2(gamma_h2_3[i]); - float2 vbetta = __half22float2(betta_h2[i]); - float2 vbetta_1 = __half22float2(betta_h2_1[i]); - float2 vbetta_2 = __half22float2(betta_h2_2[i]); - float2 vbetta_3 = __half22float2(betta_h2_3[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); - xhat_2[i].x = (vout_2.x - vbetta_2.x) / add_eps(vgamma_2.x); - xhat_3[i].x = (vout_3.x - vbetta_3.x) / add_eps(vgamma_3.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); - xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y); - xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - reduce_val[1] += - xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; - reduce_val[1] += - xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - float2 vinp_1 = __half22float2(tmp_h2_1[i]); - float2 vinp_2 = __half22float2(tmp_h2_2[i]); - float2 vinp_3 = __half22float2(tmp_h2_3[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; - xhat_2[i].x = (vinp_2.x - fmean) * var_rsqrt; - xhat_3[i].x = (vinp_3.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; - xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt; - xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - reduce_val[1] += - xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; - reduce_val[1] += - xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8 * 4; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; - float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2]; - float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); - __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); - __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2); - __half *hdres_3 = reinterpret_cast<__half *>(&dresidual_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i])); - tmp_h2_2[i].x = __float2half( - (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_2[2 * i])); - tmp_h2_3[i].x = __float2half( - (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_3[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - tmp_h2_2[i].y = __float2half( - (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - tmp_h2_3[i].y = __float2half( - (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_2[i].x = __float2half( - (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_3[i].x = __float2half( - (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_2[i].y = __float2half( - (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_3[i].y = __float2half( - (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; - ((float4 *)inp_grad)[offset + 1] = vtmp_1; - ((float4 *)inp_grad)[offset + 2] = vtmp_2; - ((float4 *)inp_grad)[offset + 3] = vtmp_3; -} - -/** -Layer norm backword, - compute the gradient of gamma, betta and input. -dbetta = sum(dout, dim=0) -xhat = (input - mean) * rsqrt(var) if mean is not nullptr - (output - betta) / gamma if mean is nullptr -dgamma = sum(xhat * dout, dim=0) -dxhat = dout * gamma -dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim) - * rsqrt(var) - -residual_grad, means, betta can be nullptr. -residual_grad will be added to dinp if it is not nullptr - which is useful in transformer layer when pre-ln -means and betta are only used to compute xhat, - (means == nullptr) ^ (betta == nullptr) should be true -*/ -template <> -void launch_ln_bw(float *gamma_grad, float *betta_grad, float *inp_grad, - const float *out_grad, const float *residual_grad, - const float *inp_or_out, const float *gamma, - const float *betta, const float *vars, - const float *means, int batch, int hidden_dim, - cudaStream_t stream[2]) { - // compute grad of gamma and betta - dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - ker_ln_bw_dgamma_dbetta<<>>( - gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, - batch, hidden_dim); - - // compute grad of input - if (hidden_dim % 4 != 0 || hidden_dim > 4096) { - throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096"); - } - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, - hidden_dim); -} - -template <> -void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad, - __half *inp_grad, const __half *out_grad, - const __half *residual_grad, const __half *inp_or_out, - const __half *gamma, const __half *betta, - const __half *vars, const __half *means, int batch, - int hidden_dim, cudaStream_t stream[2]) { - // compute grad of gamma and betta - dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - ker_ln_bw_dgamma_dbetta<__half><<>>( - gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, - batch, hidden_dim); - - // compute grad of input - if (hidden_dim % 8 != 0) { - throw std::runtime_error("hidden_dim % 8 != 0"); - } - hidden_dim >>= 3; - - if (hidden_dim * 8 <= 8192) { - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) { - hidden_dim >>= 1; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp_x2<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) { - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp_x4<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else { - throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); - } -} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu deleted file mode 100644 index 3862a699d..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu +++ /dev/null @@ -1,365 +0,0 @@ -#include -#include - -#include -#include - -#include "block_reduce.h" -#include "kernels.h" - -namespace cg = cooperative_groups; -const float EPSILON = 1e-8f; - -/** -@brief: softmax_kernel -Softmax forward kernel for - enc-self-attn, dec-self-attn, encdec-attn - -@thread -gridDim.x = dynamic -gridDim.y = batch_size -gridDim.z = nhead -blockDim.x = from_len - -@param -inp: [batch_size, nhead, from_len, to_len], softmax input. -attn_mask: [batch_size, to_len], padding tokens are -inf, - non padding tokens are 0. - attn_mask!=nullptr for enc-self-attn and enc-dec-attn - attn_mask=nullptr and mask_future=ture for dec-self-attn training - attn_mask=nullptr and mask_future=false for dec-self-attn infer -*/ -template -__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len, - int to_len, bool mask_future) { - int batch_id = blockIdx.y; - int head_id = blockIdx.z; - const int nhead = gridDim.z; - const int token_per_reduce = 1; - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - T mval[ele_per_thread]; - if (attn_mask) { - attn_mask += batch_id * to_len; - BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); - } - - inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); - for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; - token_id += gridDim.x * token_per_reduce) { - T inp_val[token_per_reduce][ele_per_thread]; - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, - REDUCE_FLOAT_INF_NEG); - } - - /* step 1. compute max */ - // thread local max - float val[token_per_reduce][ele_per_thread]; - float l_max[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_max[i] = REDUCE_FLOAT_INF_NEG; - for (int j = 0; j < ele_per_thread; j++) { - if (attn_mask) { - val[i][j] = (float)inp_val[i][j] + (float)mval[j]; - } else { - if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { - val[i][j] = REDUCE_FLOAT_INF_NEG; - } else { - val[i][j] = (float)inp_val[i][j]; - } - } - l_max[i] = fmaxf(l_max[i], val[i][j]); - } - } - // block reduce max - blockReduce(l_max); - // write shared - __shared__ float s_max[token_per_reduce]; - if (threadIdx.x == 0) { - for (int i = 0; i < token_per_reduce; i++) { - s_max[i] = l_max[i]; - } - } - __syncthreads(); - - /* step 2. compute sum */ - // thread local sum - float l_sum[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_sum[i] = 0.f; - for (int j = 0; j < ele_per_thread; j++) { - val[i][j] = __expf(val[i][j] - s_max[i]); - l_sum[i] += val[i][j]; - } - } - // block reduce sum - blockReduce(l_sum); - // write shared - __shared__ float s_sum[token_per_reduce]; - if (threadIdx.x == 0) { - for (int i = 0; i < token_per_reduce; i++) { - s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); - } - } - __syncthreads(); - - /* step 3. compute final result */ - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - for (int j = 0; j < ele_per_thread; j++) { - inp_val[i][j] = (T)(val[i][j] * s_sum[i]); - } - BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], - to_len); - } - } // blockIdx.x -} - -template -__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len, - int to_len, bool mask_future) { - int batch_id = blockIdx.y; - int head_id = blockIdx.z; - const int nhead = gridDim.z; - const int token_per_reduce = 1; - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - T mval[ele_per_thread]; - if (attn_mask) { - attn_mask += batch_id * to_len; - BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); - } - - inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); - for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; - token_id += gridDim.x * token_per_reduce) { - T inp_val[token_per_reduce][ele_per_thread]; - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, - REDUCE_FLOAT_INF_NEG); - } - - /* step 1. compute max */ - // thread local max - float val[token_per_reduce][ele_per_thread]; - float l_max[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_max[i] = REDUCE_FLOAT_INF_NEG; - for (int j = 0; j < ele_per_thread; j++) { - if (attn_mask) { - val[i][j] = (float)inp_val[i][j] + (float)mval[j]; - } else { - if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { - val[i][j] = REDUCE_FLOAT_INF_NEG; - } else { - val[i][j] = (float)inp_val[i][j]; - } - } - l_max[i] = fmaxf(l_max[i], val[i][j]); - } - } - // warp reduce max - warpReduce(l_max); - - /* step 2. compute sum */ - // thread local sum - float l_sum[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_sum[i] = 0.f; - for (int j = 0; j < ele_per_thread; j++) { - val[i][j] = __expf(val[i][j] - l_max[i]); - l_sum[i] += val[i][j]; - } - } - // warp reduce sum - warpReduce(l_sum); - - /* step 3. compute final result */ - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); - for (int j = 0; j < ele_per_thread; j++) { - inp_val[i][j] = (T)(val[i][j] * l_sum[i]); - } - BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], - to_len); - } - } // blockIdx.x -} - -/* - attn_mask!=nullptr for enc-self-attn and enc-dec-attn - attn_mask=nullptr and mask_future=ture for dec-self-attn training - attn_mask=nullptr and mask_future=false for dec-self-attn infer -*/ -template <> -void launch_attn_softmax(float *inp, const float *attn_mask, - int batch_size, int nhead, int from_len, - int to_len, bool mask_future, - cudaStream_t stream) { - dim3 grid_dim(1, batch_size, nhead); - if (to_len <= 32) { - ker_attn_softmax_lt32<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 64) { - ker_attn_softmax_lt32<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 128) { - grid_dim.x = 16; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 256) { - grid_dim.x = 32; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 512) { - grid_dim.x = 64; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else { - throw std::runtime_error( - "Sequence length greater than 512 is currently not supported"); - } -} - -template <> -void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask, - int batch_size, int nhead, int from_len, - int to_len, bool mask_future, - cudaStream_t stream) { - dim3 grid_dim(1, batch_size, nhead); - if (to_len <= 32) { - ker_attn_softmax_lt32<__half, 32, 1><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 64) { - ker_attn_softmax_lt32<__half, 32, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 128) { - grid_dim.x = 8; - ker_attn_softmax<__half, 64, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 256) { - grid_dim.x = 16; - ker_attn_softmax<__half, 128, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 512) { - grid_dim.x = 32; - ker_attn_softmax<__half, 256, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else { - throw std::runtime_error( - "Sequence length greater than 512 is currently not supported"); - } -} - -/** -@brief: ker_attn_softmax_bw -Softmax backward in self attention. - -@thread -gridDim.x = batch_size * nhead * seq_len / warps_per_block -blockDim.x = WARP_SIZE -blockDim.y = warps_per_block - -@param -grad: [batch_size, nhead, seq_len, seq_len], output grad. -output: [batch_size, nhead, seq_len, seq_len], output of softmax forward. -*/ -template -__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) { - int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; - int offset = batch_idx * softmax_length + threadIdx.x; - - grad += offset; - inp += offset; - - T grad_reg[ITERATIONS]; - T inp_reg[ITERATIONS]; - float sum = 0.0; - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) { - grad_reg[i] = grad[i * WARP_SIZE]; - inp_reg[i] = inp[i * WARP_SIZE]; - sum += (float)grad_reg[i] * (float)inp_reg[i]; - } - } - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) - grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum)); - } -} - -template -void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, - int softmax_len, cudaStream_t stream) { - const int warps_per_block = 4; - // rows = batch_size * nhead * from_len - dim3 grid_dim(rows / warps_per_block); - dim3 block_dim(WARP_SIZE, warps_per_block); - - if (softmax_len <= 32) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 64) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 128) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 256) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 384) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 512) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 768) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 1024) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 2048) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else - throw std::runtime_error( - std::string( - "Special sequence length found in softmax backward, seq_len: ") + - std::to_string(softmax_len)); -} - -template void launch_attn_softmax_bw<__half>(__half *out_grad, - const __half *soft_inp, int rows, - int softmax_len, - cudaStream_t stream); -template void launch_attn_softmax_bw(float *out_grad, - const float *soft_inp, int rows, - int softmax_len, - cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu deleted file mode 100644 index 04de3c092..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu +++ /dev/null @@ -1,314 +0,0 @@ -#include -#include -#include - -#include "kernels.h" - -using namespace cub; - -/** -@brief: transform_0213 -Split the attention heads and reshape input -during backward progress of encoder self-attention - -@thread -gridDim.x = batch_size -gridDim.y = seq_len -blockDim.x = min(hidden_dim, MAX_THREADS) - -@param -input: [batch_size, seq_len, hidden_dim] -output: [batch_size, nhead, seq_len, head_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -nhead: number of attention heads -*/ - -template -__global__ void transform_0213(T *output, const T *input, int hidden_dim, - int head_dim); - -template <> -__global__ void transform_0213(float *output, const float *input, - int hidden_dim, int head_dim) { - int batch_id = blockIdx.x; - int token_id = blockIdx.y; - int seq_len = gridDim.y; - int nhead = hidden_dim / head_dim; - - // [b, s, h] - int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); - // [b, nh, s, ad] - int trg_offset = - flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - float4 vinput4; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinput4 = input4[src_offset + i]; - - int head_id = i / head_dim; - int dim_id = i % head_dim; - int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); - res4[trg_offset + cur_trg_offset] = vinput4; - } -} - -template <> -__global__ void transform_0213<__half>(__half *output, const __half *input, - int hidden_dim, int head_dim) { - int batch_id = blockIdx.x; - int token_id = blockIdx.y; - int seq_len = gridDim.y; - int nhead = hidden_dim / head_dim; - - // [b, s, h] - int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); - // [b, nh, s, ad] - int trg_offset = - flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - float4 vinput4; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinput4 = input4[src_offset + i]; - - int head_id = i / head_dim; - int dim_id = i % head_dim; - int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); - res4[trg_offset + cur_trg_offset] = vinput4; - } -} - -// [b, s, h] -> [b, nh, s, ad] -template <> -void launch_transform_0213(float *output, const float *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, cudaStream_t stream) { - hidden_dim >>= 2; - int head_dim = hidden_dim / nhead; - - dim3 grid_dim(batch_size, seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - transform_0213 - <<>>(output, input, hidden_dim, head_dim); -} - -template <> -void launch_transform_0213<__half>(__half *output, const __half *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, cudaStream_t stream) { - hidden_dim >>= 3; - int head_dim = hidden_dim / nhead; - - dim3 grid_dim(batch_size, seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - transform_0213<__half> - <<>>(output, input, hidden_dim, head_dim); -} - -/** -@brief: bias_add_transform_20314 -Add bias to input, transform from -[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4] - -@thread -gridDim.x = dim_0 -gridDim.y = dim_1 -gridDim.z = dim_2 -blockDim.x = min(dim_3 * dim_4, MAX_THREADS) - -@param -input: [dim_0, dim_1, dim_2, dim_3, dim_4] -bias: [dim_2, dim_3, dim_4] -output: [dim_2, dim_0, dim_3, dim_1, dim_4] -*/ -template -__global__ void bias_add_transform_20314(T *output, const T *input, - const T *bias, int dim_3, int dim_4); - -template <> -__global__ void bias_add_transform_20314(float *output, - const float *input, - const float *bias, int dim_3, - int dim_4) { - int id0 = blockIdx.x; - int id1 = blockIdx.y; - int id2 = blockIdx.z; - int dim_0 = gridDim.x; - int dim_1 = gridDim.y; - int dim_2 = gridDim.z; - int dim_34 = dim_3 * dim_4; - - int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); - int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); - int bias_offset = flat_2dim(id2, 0, dim_34); - - const float4 *qkv4 = reinterpret_cast(input); - const float4 *bias4 = reinterpret_cast(bias); - float4 *res4 = reinterpret_cast(output); - float4 vqkv4; - float4 vbias4; - float4 vres4; - - for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { - vqkv4 = qkv4[src_offset + i]; - vbias4 = bias4[bias_offset + i]; - vres4.x = vqkv4.x + vbias4.x; - vres4.y = vqkv4.y + vbias4.y; - vres4.z = vqkv4.z + vbias4.z; - vres4.w = vqkv4.w + vbias4.w; - - int id3 = i / dim_4; - int id4 = i % dim_4; - int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); - res4[trg_offset + cur_trg_offset] = vres4; - } -} - -template <> -__global__ void bias_add_transform_20314<__half>(__half *output, - const __half *input, - const __half *bias, int dim_3, - int dim_4) { - int id0 = blockIdx.x; - int id1 = blockIdx.y; - int id2 = blockIdx.z; - int dim_0 = gridDim.x; - int dim_1 = gridDim.y; - int dim_2 = gridDim.z; - int dim_34 = dim_3 * dim_4; - - int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); - int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); - int bias_offset = flat_2dim(id2, 0, dim_34); - - const float4 *qkv4 = reinterpret_cast(input); - const float4 *bias4 = reinterpret_cast(bias); - float4 *res4 = reinterpret_cast(output); - float4 vqkv4; - float4 vbias4; - float4 vres4; - __half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4); - __half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4); - __half2 *h2_res = reinterpret_cast<__half2 *>(&vres4); - - for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { - vqkv4 = qkv4[src_offset + i]; - vbias4 = bias4[bias_offset + i]; - h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]); - h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]); - h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]); - h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]); - - int id3 = i / dim_4; - int id4 = i % dim_4; - int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); - res4[trg_offset + cur_trg_offset] = vres4; - } -} - -// [b, s, 3, h] -> [3, b, nh, s, ad] -template <> -void launch_bias_add_transform_20314(float *output, const float *input, - const float *bias, int dim_0, - int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream) { - dim_4 >>= 2; - - dim3 grid_dim(dim_0, dim_1, dim_2); - dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); - - bias_add_transform_20314 - <<>>(output, input, bias, dim_3, dim_4); -} - -template <> -void launch_bias_add_transform_20314<__half>(__half *output, - const __half *input, - const __half *bias, int dim_0, - int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream) { - dim_4 >>= 3; - - dim3 grid_dim(dim_0, dim_1, dim_2); - dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); - - bias_add_transform_20314<__half> - <<>>(output, input, bias, dim_3, dim_4); -} - -/** -@brief: transform4d_0213 -Reshape the input matrix to merge the heads - -@thread -gridDim.x = (num_all + max_block_thread - 1) / max_block_thread -blockDim.x = max_block_thread - -@param -input: [trans_count, batch_size, nhead, seq_len, head_dim] -output: [batch_size, seq_len, trans_count, nhead, head_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -nhead: number of attention heads -trans_count: 1 or 3, the count of matrice need to be transformed -*/ -template -__global__ void transform4d_0213(T *output, const T *input, int batch_size, - int seq_len, int trans_count, int nhead, - int head_dim, int num_all) { - int offset = blockIdx.x * blockDim.x + threadIdx.x; - if (offset >= num_all) { - return; - } - int trans_id, batch_id, head_id, token_id, dim_id; - decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id, - &batch_id, &head_id, &token_id, &dim_id); - // [b, s, tc, nh, ad] - int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id, - seq_len, trans_count, nhead, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - res4[trg_offset] = input4[offset]; -} - -// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] -template <> -void launch_transform4d_0213(float *output, const float *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, int trans_count, - cudaStream_t stream) { - hidden_dim >>= 2; - int head_dim = hidden_dim / nhead; - int num_all = batch_size * seq_len * trans_count * hidden_dim; - int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; - - transform4d_0213<<>>( - output, input, batch_size, seq_len, trans_count, nhead, head_dim, - num_all); -} - -template <> -void launch_transform4d_0213<__half>(__half *output, const __half *input, - int batch_size, int seq_len, - int hidden_dim, int nhead, int trans_count, - cudaStream_t stream) { - hidden_dim >>= 3; - int head_dim = hidden_dim / nhead; - int num_all = batch_size * seq_len * trans_count * hidden_dim; - int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; - - transform4d_0213<__half><<>>( - output, input, batch_size, seq_len, trans_count, nhead, head_dim, - num_all); -} diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp deleted file mode 100644 index d08f3dbc7..000000000 --- a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp +++ /dev/null @@ -1,406 +0,0 @@ -#include "multihead_attention_1d.h" - -#include -#include -#include - -#if TORCH_VERSION_MAJOR > 1 || \ - (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13) -#include -#else -#include -#endif -#include - -#include "context.h" -#include "kernels.h" - -template -MultiHeadAttention::MultiHeadAttention(int layer_id, int max_batch_tokens, - int max_seq_len, int hidden_size, - int num_heads, - float attn_prob_dropout_ratio, - float hidden_output_dropout_ratio, - bool pre_or_postLayerNorm) - : _layer_id(layer_id), - _max_batch_tokens(max_batch_tokens), - _max_seq_len(max_seq_len), - _hidden_size(hidden_size), - _heads(num_heads), - _training(true), - _pre_or_postLayerNorm(pre_or_postLayerNorm), - _qkv_linear( - typename FeedForward::Config(3 * hidden_size, hidden_size)), - _attn_out_linear( - typename FeedForward::Config(hidden_size, hidden_size)), - _attn_ln(typename Normalize_Layer::Config(hidden_size, false), - _max_batch_tokens), - _softmax(typename Softmax::Config(num_heads)), - _attn_prob_dropout(typename Dropout::Config(attn_prob_dropout_ratio), - _max_batch_tokens * _heads * _max_seq_len), - _attn_dropout(typename Dropout::Config(hidden_output_dropout_ratio), - _max_batch_tokens * _hidden_size), - _attn_scores(typename StridedBatchGemm::Config( - (T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T, - CUBLAS_OP_N)), - _attn_context(typename StridedBatchGemm::Config( - T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) { - assert(_hidden_size % _heads == 0); -} - -template -MultiHeadAttention::~MultiHeadAttention() { - free_mem_buffer(); -} - -template -void MultiHeadAttention::attn_layer_fw(const T *input_ptr, - const T *input_mask_ptr, - T *output_ptr, T *buffer) { - T *q_tf_ptr = _qkv_ptr; - T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size; - T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size; - - if (_pre_or_postLayerNorm) { - _attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr, - _batch_tokens, _stream); - } - const T *gemmQKV_inp_ptr = - _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr; - _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size); - _qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer, - _cublasHandle); - - launch_bias_add_transform_20314(q_tf_ptr, buffer, _attn_qkvb_ptr, - _batch_size, _seq_len, 3, _heads / pg_size, - _hidden_size / _heads, _stream); - - // attention scores, q*k - _attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr, - _cublasHandle); - - // Softmax + Mask - _softmax.reset_size(_heads / pg_size); - _softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len, - _seq_len, _stream, true); - - // attn prob dropout. - _attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr, - _batch_heads * _seq_len * _seq_len, _stream); - - // attention context, score * v - _attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr, - _cublasHandle); - - // [b, nh, s, ad] -> [b, s, nh, ad] - launch_transform4d_0213(_attn_o_inp_ptr, buffer, _batch_size, _seq_len, - _hidden_size / pg_size, _heads / pg_size, 1, - _stream); - - _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size); - _attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr, - output_ptr, _cublasHandle); - - // allreduce - if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) { - } else { - auto data_type = torch::kFloat; - if (typeid(T) != typeid(float)) { - data_type = torch::kHalf; - } - auto output_tensor = torch::from_blob( - output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)}, - torch::TensorOptions(torch::kCUDA).dtype(data_type)); - std::vector allreduce_tensors = {output_tensor}; - auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions()); - work->wait(); - } - - _attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr, - _attn_ob_ptr, _batch_tokens, _hidden_size, - _stream); - if (!_pre_or_postLayerNorm) { - // in-place ln since ln-input will not be used in post-ln mode - _attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, - _batch_tokens, _stream); - } -} - -template -void MultiHeadAttention::Forward(const T *input_ptr, const T *input_mask_ptr, - T *out_ptr) { - _stream = Context::Instance().get_stream(); - _cublasHandle = Context::Instance().get_cublashandle(); - T *attn_buffer = _shared_mem_ptr; // 3 * _batch_dim - - attn_layer_fw(input_ptr, input_mask_ptr, out_ptr, attn_buffer); -} - -template -void MultiHeadAttention::attn_layer_bw(const T *input_ptr, - const T *input_mask_ptr, - const T *output_ptr, - const T *grad_output_ptr, - T *grad_input_ptr, T *buffer) { - cudaStream_t streams[2] = {_stream, _stream}; - - const T *q_tf_ptr = _qkv_ptr; - const T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size; - const T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size; - // batch_dim = batch_size * seq_len * hidden_size - // buffer size: batch_dim * 3 + max(batch_dim * 3, - // batch_size * head_num * seq_len * seq_len) - T *grad_residual_ptr = buffer; - buffer += _batch_dim; - - T *grad_input_buf_ptr = buffer; // batch_dim - T *grad_qkv_5d_ptr = buffer; // batch_dim * 3 - buffer += 3 * _batch_dim / pg_size; - - T *grad_qkv_4d_ptr = buffer; // batch_dim * 3 - T *grad_softmax_ptr = buffer; // batch_size * head_num * seq_len * seq_len - // buffer += max(3 * _batch_dim, - // batch_size * head_num * seq_len * seq_len); - - if (_pre_or_postLayerNorm) { - _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, - grad_output_ptr, _batch_tokens, - _hidden_size, _stream); - } else { - _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr, - grad_output_ptr, nullptr, output_ptr, _attn_nw_ptr, - _attn_nb_ptr, _batch_tokens, streams); - _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, - grad_residual_ptr, _batch_tokens, - _hidden_size, _stream); - } - - // bw of output project - _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size); - _attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr, - _attn_ow_ptr, _grad_attn_ow_ptr, _grad_attn_ob_ptr, - _cublasHandle, _stream, grad_input_buf_ptr, nullptr, - false); - launch_transform_0213(grad_input_ptr, grad_input_buf_ptr, _batch_size, - _seq_len, _hidden_size / pg_size, _heads / pg_size, - _stream); - - // bw of score * v - _attn_context.Backward( - _batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle, - grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr); - - _attn_prob_dropout.d_dropout(grad_softmax_ptr, - _batch_heads * _seq_len * _seq_len, _stream); - - _softmax.reset_size(_heads / pg_size); - _softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len, - _seq_len, _stream); - - // bw of q * k - _attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr, - _cublasHandle, grad_qkv_5d_ptr + _batch_dim / pg_size, - grad_qkv_5d_ptr); - - // [3, b, nh, s, ad] -> [b, s, 3, h] - launch_transform4d_0213(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size, - _seq_len, _hidden_size / pg_size, _heads / pg_size, - 3, _stream); - - const T *gemmQKV_inp_ptr = - _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr; - _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size); - _qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr, - _attn_qkvw_ptr, _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr, - _cublasHandle, _stream, grad_input_buf_ptr, nullptr, - true); - - // allreduce - if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) { - } else { - auto data_type = torch::kFloat; - if (typeid(T) != typeid(float)) { - data_type = torch::kHalf; - } - auto grad_input_tensor = - torch::from_blob(grad_input_buf_ptr, - {int(_batch_size), int(_seq_len), int(_hidden_size)}, - torch::TensorOptions(torch::kCUDA).dtype(data_type)); - std::vector allreduce_tensors = {grad_input_tensor}; - auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions()); - work->wait(); - } - - if (_pre_or_postLayerNorm) { - _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr, - grad_input_buf_ptr, grad_output_ptr, gemmQKV_inp_ptr, - _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams); - } else { - // FIXME later - launch_fused_add2(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr, - _batch_size, _seq_len, _hidden_size, _stream); - } -} - -template -void MultiHeadAttention::Backward(const T *grad_output_ptr, - const T *input_ptr, const T *output_ptr, - const T *input_mask_ptr, - T *grad_input_ptr) { - _stream = Context::Instance().get_stream(); - _cublasHandle = Context::Instance().get_cublashandle(); - T *buffer = _shared_mem_ptr; - - /* - buffer size needed by attn bw: - 4 * _batch_dim + max(3 * _batch_dim, - _batch_size * _head_num * _seq_len * _seq_len); - */ - attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr, - grad_input_ptr, buffer); -} - -template -void MultiHeadAttention::SetTrainingMode(bool training) { - // Dropout will be skipped when not in training model. - _attn_prob_dropout.SetTrainingMode(training); - _attn_dropout.SetTrainingMode(training); -} - -template -T *MultiHeadAttention::_shared_mem_ptr = nullptr; - -template class MultiHeadAttention; -template class MultiHeadAttention<__half>; - -// x is torch::Tensor -#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -static std::unordered_map> s_multihead_attention; - -template -int create_multihead_attention(int layer_id, int max_batch_tokens, - int max_seq_len, int hidden_dim, int num_heads, - float attn_prob_dropout_ratio, - float hidden_dropout_ratio, - bool pre_or_postLayerNorm, - c10::intrusive_ptr pg_) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - Context::Instance().set_stream(stream); - auto layer = std::make_shared>( - layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads, - attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm); - - layer->SetPG(pg_); - - s_multihead_attention[layer_id] = layer; - - std::string dtype = (std::is_same::value) ? "half" : "float"; - - return 0; -} - -template -std::vector multihead_attention_fw( - int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask, - const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias, - const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias, - const torch::Tensor &norm_weight, const torch::Tensor &norm_bias, - bool training_mode, bool prelayernorm) { - CHECK_INPUT(input); - CHECK_INPUT(input_mask); - - const T *input_ptr = (const T *)input.data_ptr(); - const T *input_mask_ptr = (const T *)input_mask.data_ptr(); - - auto output = torch::empty_like(input); - T *out_ptr = (T *)output.data_ptr(); - - std::shared_ptr> layer = - std::static_pointer_cast>( - s_multihead_attention[layer_id]); - layer->set_cur_batch_shape(input.size(0), input.size(1)); - layer->SetTrainingMode(training_mode); - - layer->_attn_qkvw_ptr = (const T *)in_proj_weight.data_ptr(); - layer->_attn_qkvb_ptr = (const T *)in_proj_bias.data_ptr(); - layer->_attn_ow_ptr = (const T *)out_proj_weight.data_ptr(); - layer->_attn_ob_ptr = (const T *)out_proj_bias.data_ptr(); - layer->_attn_nw_ptr = (const T *)norm_weight.data_ptr(); - layer->_attn_nb_ptr = (const T *)norm_bias.data_ptr(); - - layer->Forward(input_ptr, input_mask_ptr, out_ptr); - - return {output}; -} - -template -std::vector multihead_attention_bw( - int layer_id, const torch::Tensor &grad_dec_output, - const torch::Tensor &output, const torch::Tensor &input, - const torch::Tensor &input_mask, const torch::Tensor &in_proj_weight, - const torch::Tensor &in_proj_bias, const torch::Tensor &out_proj_weight, - const torch::Tensor &out_proj_bias, const torch::Tensor &norm_weight, - const torch::Tensor &norm_bias) { - auto g_output = grad_dec_output.contiguous(); - CHECK_INPUT(g_output); - CHECK_INPUT(output); - CHECK_INPUT(input); - CHECK_INPUT(input_mask); - - auto grad_input = torch::empty_like(input); - auto grad_in_proj_weight = torch::empty_like(in_proj_weight); - auto grad_in_proj_bias = torch::empty_like(in_proj_bias); - auto grad_out_proj_weight = torch::empty_like(out_proj_weight); - auto grad_out_proj_bias = torch::empty_like(out_proj_bias); - auto grad_norm_weight = torch::empty_like(norm_weight); - auto grad_norm_bias = torch::empty_like(norm_bias); - - // inputs. - const T *grad_dec_output_ptr = (const T *)g_output.data_ptr(); - const T *input_ptr = (const T *)input.data_ptr(); - const T *output_ptr = (const T *)output.data_ptr(); - const T *input_mask_ptr = (const T *)input_mask.data_ptr(); - - // outputs. - T *grad_input_ptr = (T *)grad_input.data_ptr(); - - std::shared_ptr> layer = - std::static_pointer_cast>( - s_multihead_attention[layer_id]); - layer->set_cur_batch_shape(g_output.size(0), g_output.size(1)); - - layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr(); - layer->_grad_attn_qkvb_ptr = (T *)grad_in_proj_bias.data_ptr(); - layer->_grad_attn_ow_ptr = (T *)grad_out_proj_weight.data_ptr(); - layer->_grad_attn_ob_ptr = (T *)grad_out_proj_bias.data_ptr(); - layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr(); - layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr(); - - layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr, - grad_input_ptr); - - return {grad_input, grad_in_proj_weight, grad_in_proj_bias, - grad_out_proj_weight, grad_out_proj_bias, grad_norm_weight, - grad_norm_bias}; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("multihead_attention_fw_fp32", &multihead_attention_fw, - "Multi-head Attention forward with fp32 (CUDA)"); - m.def("multihead_attention_fw_fp16", &multihead_attention_fw<__half>, - "Multi-head Attention forward with fp16 (CUDA)"); - m.def("multihead_attention_bw_fp32", &multihead_attention_bw, - "Multi-head Attention backward with fp32 (CUDA)"); - m.def("multihead_attention_bw_fp16", &multihead_attention_bw<__half>, - "Multi-head Attention backward with fp16 (CUDA)"); - m.def("create_multihead_attention_fp32", &create_multihead_attention, - "Create Multi-head Attention with fp32 (CUDA)"); - m.def("create_multihead_attention_fp16", &create_multihead_attention<__half>, - "Create Multi-head Attention with fp16 (CUDA)"); -} diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h deleted file mode 100644 index 6505eb31f..000000000 --- a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h +++ /dev/null @@ -1,167 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#if TORCH_VERSION_MAJOR > 1 || \ - (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13) -#include -#else -#include -#endif - -#include -#include - -#include "cuda_util.h" -#include "dropout.h" -#include "feed_forward.h" -#include "normalize_layer.h" -#include "softmax.h" -#include "strided_batch_gemm.h" - -template -class MultiHeadAttention { - public: - MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len, - int hidden_size, int num_heads, float attn_dropout_ratio, - float hidden_output_dropout_ratio, - bool pre_or_postLayerNorm); - - virtual ~MultiHeadAttention(); - - void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr); - - void Backward(const T *grad_output_ptr, const T *input_ptr, - const T *output_ptr, const T *input_mask_ptr, - T *grad_input_ptr); - - void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, - T *buffer); - - void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, - const T *output_ptr, const T *grad_output_ptr, - T *grad_input_attn_layer_bwptr, T *buffer); - - void set_cur_batch_shape(int batch_size, int seq_len) { - _batch_size = batch_size; - _seq_len = seq_len; - _batch_tokens = batch_size * seq_len; - _batch_heads = batch_size * _heads / pg_size; - _batch_dim = _batch_tokens * _hidden_size; - _attn_scores.SetConfig(_seq_len, _seq_len, _hidden_size / _heads); - _attn_context.SetConfig(_hidden_size / _heads, _seq_len, _seq_len); - } - - void SetTrainingMode(bool training); - inline bool IsTrainingMode() const { return _training; } - - void SetPG(c10::intrusive_ptr pg_) { - pg = pg_; - pg_size = 1; - if (pg != c10::detail::UniqueVoidPtr()) { - pg_size = pg->getSize(); - } - allocate_mem_buffer(); - } - - // weights ptr - const T *_attn_qkvw_ptr; - const T *_attn_qkvb_ptr; - const T *_attn_ow_ptr; - const T *_attn_ob_ptr; - const T *_attn_nw_ptr; - const T *_attn_nb_ptr; - - // grads ptr - T *_grad_attn_qkvw_ptr; - T *_grad_attn_qkvb_ptr; - T *_grad_attn_ow_ptr; - T *_grad_attn_ob_ptr; - T *_grad_attn_nw_ptr; - T *_grad_attn_nb_ptr; - - private: - void allocate_mem_buffer() { - // allocate local gpu memory - if (_pre_or_postLayerNorm) { - _gemmQKV_inp_ptr = cuda_malloc(_max_batch_tokens * _hidden_size); - } else { - _gemmQKV_inp_ptr = nullptr; - } - - _qkv_ptr = cuda_malloc(_max_batch_tokens * _hidden_size * 3); - _soft_out_ptr = - cuda_malloc(_max_batch_tokens * _heads / pg_size * _max_seq_len); - _ctx_bufB_ptr = - cuda_malloc(_max_batch_tokens * _heads / pg_size * _max_seq_len); - _attn_o_inp_ptr = cuda_malloc(_max_batch_tokens * _hidden_size); - - // buffer size needed by attn bw - size_t smem_size = - 4 * _max_batch_tokens * _hidden_size / pg_size + - std::max(3 * _max_batch_tokens * _hidden_size / pg_size, - _max_batch_tokens * _heads / pg_size * _max_seq_len); - - if (!_shared_mem_ptr) { - cuda_free(_shared_mem_ptr); - _shared_mem_ptr = cuda_malloc(smem_size); - } - } - - void free_mem_buffer() { - // free local gpu memory - cuda_free(_gemmQKV_inp_ptr); - cuda_free(_qkv_ptr); - cuda_free(_soft_out_ptr); - cuda_free(_ctx_bufB_ptr); - cuda_free(_attn_o_inp_ptr); - - // free shared gpu memory between layers - cuda_free(_shared_mem_ptr); - _shared_mem_ptr = nullptr; - } - - // const parameter between batch - const size_t _layer_id; - const size_t _hidden_size; - const size_t _heads; - const size_t _max_batch_tokens; - const size_t _max_seq_len; - const bool _pre_or_postLayerNorm; - // dynamic parameter between batch - size_t _batch_size; - size_t _seq_len; - size_t _batch_tokens; - size_t _batch_heads; - size_t _batch_dim; - bool _training; - - cublasHandle_t _cublasHandle; - cudaStream_t _stream; - - // layers - FeedForward _qkv_linear; - FeedForward _attn_out_linear; - Normalize_Layer _attn_ln; - Softmax _softmax; - Dropout _attn_prob_dropout; - Dropout _attn_dropout; - StridedBatchGemm _attn_scores; - StridedBatchGemm _attn_context; - - // local GPU memory - T *_gemmQKV_inp_ptr; - T *_qkv_ptr; - T *_soft_out_ptr; - T *_ctx_bufB_ptr; - T *_attn_o_inp_ptr; - // shared GPU memory between layer - static T *_shared_mem_ptr; - - c10::intrusive_ptr pg; - int pg_size; -}; diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp b/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp deleted file mode 100644 index 844427294..000000000 --- a/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp +++ /dev/null @@ -1,8 +0,0 @@ -#include - -#include "linear.h" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32, - "Linear SiLU (INT8)"); -} diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu deleted file mode 100644 index a30d02a4c..000000000 --- a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu +++ /dev/null @@ -1,162 +0,0 @@ -// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu - -#include "linear.h" -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 - torch::Tensor weight, // INT8 - torch::Tensor bias, // FP32 - float alpha, // FP32 - float beta // FP32 -) { - auto M = input.size(0); - auto N = weight.size(0); - auto K = input.size(1); - - using ElementOutput = float; - using ElementAccumulator = int32_t; - using ElementComputeEpilogue = float; - using ElementInputA = int8_t; // <- data type of elements in input matrix A - using ElementInputB = int8_t; // <- data type of elements in input matrix B - - // The code section below describes matrix layout of input and output - // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major - // for Matrix C - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - -#if CUDA_ARCH >= 800 - using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu< - ElementOutput, // <- data type of output matrix - 128 / cutlass::sizeof_bits< - ElementOutput>::value, // <- this is the number of elements per - // vectorized memory access. For half - // precision, it's 8 elements. This - // becomes the vector width of math - // instructions in epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue // <- data type for alpha in linear combination - // function - >; - using Gemm = cutlass::gemm::device::Gemm< - int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, - ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, - EpilogueOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -#elif CUDA_ARCH >= 750 - using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu< - ElementOutput, // <- data type of output matrix - 128 / cutlass::sizeof_bits< - ElementOutput>::value, // <- this is the number of elements per - // vectorized memory access. For half - // precision, it's 8 elements. This - // becomes the vector width of math - // instructions in epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue // <- data type for alpha in linear combination - // function - >; - - using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, - ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>; - using Gemm = cutlass::gemm::device::Gemm< - int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, - ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, - DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, - DefaultGemmCfg::InstructionShape, - EpilogueOp>; -#elif CUDA_ARCH >= 700 - #define USE_TORCH_SILU - using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< - cutlass::arch::OpClassSimt, cutlass::arch::Sm70, - ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>; - using Gemm = cutlass::gemm::device::Gemm< - int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, - ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassSimt, cutlass::arch::Sm70, - DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, - DefaultGemmCfg::InstructionShape, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>; -#else - #error "Unsupported cuda arch" -#endif - - auto input_size = cutlass::MatrixCoord(M, K); - auto weight_size = cutlass::MatrixCoord(K, N); - auto output_size = cutlass::MatrixCoord(M, N); - - auto device = input.device(); - // use the broadcasted bias as the output - auto out = bias.to(device).view({1, -1}).repeat({M, 1}); - - // constexpr int kSparse = Gemm::kSparse; - // How many elements of A are covered per ElementE - // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; - // The size of individual meta data - // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; - cutlass::gemm::GemmCoord problem_size(M, N, K); - - cutlass::TensorRef input_ref( - input.data_ptr(), LayoutInputA::packed(input_size)); - cutlass::TensorRef weight_ref( - weight.data_ptr(), LayoutInputB::packed(weight_size)); - cutlass::TensorRef out_ref( - out.data_ptr(), LayoutOutput::packed(output_size)); - - typename Gemm::Arguments arguments{ - problem_size, // <- problem size of matrix multiplication - input_ref, // <- reference to matrix A on device - weight_ref, // <- reference to matrix B on device - out_ref, // <- reference to matrix C on device - out_ref, // <- reference to matrix D on device - {alpha, beta}, 1}; - Gemm gemm_op; - - // Using the arguments, query for extra workspace required for matrix - // multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - cutlass::device_memory::allocation workspace(workspace_size); - - // Check the problem size is supported or not - cutlass::Status status = gemm_op.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement"); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm_op.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize"); - } - - status = gemm_op(); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot run"); - } -#ifdef USE_TORCH_SILU -#undef USE_TORCH_SILU - out = torch::silu(out); -#endif - return out; -} diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h deleted file mode 100644 index b62a27f3f..000000000 --- a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h +++ /dev/null @@ -1,12 +0,0 @@ -#include -#include - -#include -#include - -torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 - torch::Tensor weight, // INT8 - torch::Tensor bias, // FP32 - float alpha, // FP32 - float beta // FP32 -); diff --git a/colossalai/kernel/cuda_native/mha/__init__.py b/colossalai/kernel/cuda_native/mha/__init__.py deleted file mode 100644 index cad36e598..000000000 --- a/colossalai/kernel/cuda_native/mha/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .mha import ColoAttention - -__all__ = ["ColoAttention"] diff --git a/colossalai/kernel/cuda_native/mha/flash_attn_2.py b/colossalai/kernel/cuda_native/mha/flash_attn_2.py deleted file mode 100644 index 9ee83915b..000000000 --- a/colossalai/kernel/cuda_native/mha/flash_attn_2.py +++ /dev/null @@ -1,80 +0,0 @@ -import warnings -from typing import Optional - -import torch - - -def is_ampere_or_better_gpu(): - if torch.cuda.is_available(): - device = torch.device("cuda") - properties = torch.cuda.get_device_properties(device) - if properties.major >= 8: # Ampere GPUs or newer - return True - return False - - -# "Check Ampere GPUs or newer" -HAS_FLASH_ATTN = False -if is_ampere_or_better_gpu(): - HAS_FLASH_ATTN = True -else: - warnings.warn("FlashAttention only supports Ampere GPUs or newer.") - HAS_FLASH_ATTN = False -try: - from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func - - HAS_FLASH_ATTN = True -except ImportError: - warnings.warn("please install flash_attn from https://github.com/HazyResearch/flash-attention") - HAS_FLASH_ATTN = False - -if HAS_FLASH_ATTN: - pass - - from .utils import SeqLenInfo - - def flash_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q: SeqLenInfo, - seq_len_info_kv: SeqLenInfo, - bias: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - scale: float = None, - causal: bool = False, - padded: bool = False, - ): - """ - Arguments: - q: (batch, q_seqlen, nheads, headdim) - k: (batch, kv_seqlen, nheads, headdim) - v: (batch, kv_seqlen, nheads, headdim) - batch_size: int. - seq_len: int. - dropout_p: float. Dropout probability. - sm_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - Return: - attn_out: (batch, q_seqlen, nheads, headdim). - """ - if padded: - if seq_len_info_kv == None: - seq_len_info_kv = seq_len_info_q - - attn_out = flash_attn_varlen_func( - q, - k, - v, - seq_len_info_q.cu_seqlens, - seq_len_info_kv.cu_seqlens, - seq_len_info_q.max_seqlen, - seq_len_info_kv.max_seqlen, - dropout_p, - scale, - causal, - ) - else: - attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal) - return attn_out diff --git a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py deleted file mode 100644 index 649e74d61..000000000 --- a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py +++ /dev/null @@ -1,70 +0,0 @@ -import warnings - -HAS_MEM_EFF_ATTN = False -try: - from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention - from xformers.ops.fmha.attn_bias import ( - BlockDiagonalCausalMask, - BlockDiagonalMask, - LowerTriangularMask, - LowerTriangularMaskWithTensorBias, - ) - - HAS_MEM_EFF_ATTN = True -except ImportError: - warnings.warn("please install xformers from https://github.com/facebookresearch/xformers") - HAS_MEM_EFF_ATTN = False - -if HAS_MEM_EFF_ATTN: - """ - A general attention module using the flash attention kernels from xformers: - https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha - """ - from typing import Optional - - import torch - - from .utils import SeqLenInfo - - allow_alibi = True - for op in MemoryEfficientAttentionCutlassOp: - allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) - - def mem_eff_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q: SeqLenInfo, - seq_len_info_kv: SeqLenInfo, - bias: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - scale: float = None, - causal: bool = False, - padded: bool = False, - ): - attn_bias = None - if padded: # bert style - if not causal: - attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) - else: - attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) - elif causal: # gpt style - attn_bias = LowerTriangularMask() - - if bias is not None: # alibi / relative position embedding - assert allow_alibi, "flash attention with bias is not supported in this system." - assert causal, "attention with bias is only supported for causal attention so far." - attn_bias = attn_bias.add_bias(bias) - - if padded: - q = q.unsqueeze(0) - k = k.unsqueeze(0) - v = v.unsqueeze(0) - - out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale) - - # shape: (b*s, n, d) - if padded: - out = out.squeeze(0) - - return out diff --git a/colossalai/kernel/cuda_native/mha/mha.py b/colossalai/kernel/cuda_native/mha/mha.py deleted file mode 100644 index 1c778439d..000000000 --- a/colossalai/kernel/cuda_native/mha/mha.py +++ /dev/null @@ -1,113 +0,0 @@ -import math -from typing import Optional - -import torch -from einops import rearrange - -from ..scaled_softmax import AttnMaskType -from .flash_attn_2 import HAS_FLASH_ATTN -from .mem_eff_attn import HAS_MEM_EFF_ATTN -from .utils import Repad, SeqLenInfo, Unpad - -if HAS_FLASH_ATTN: - from .flash_attn_2 import flash_attention -if HAS_MEM_EFF_ATTN: - from .mem_eff_attn import mem_eff_attention - - -class ColoAttention(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None): - super().__init__() - assert ( - embed_dim % num_heads == 0 - ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." - if scale is not None: - self.scale = scale - else: - self.scale = 1 / math.sqrt(embed_dim // num_heads) - self.dropout = dropout - - if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN: - raise Exception("flash attention can not support!") - - @staticmethod - def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: - return Unpad.apply(tensor, indices) - - @staticmethod - def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: - return Repad.apply(tensor, indices, batch_size, seq_len) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - attn_mask_type: Optional[AttnMaskType] = None, - bias: Optional[torch.Tensor] = None, - ): - attn = None - if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None: - attn = flash_attention - else: - attn = mem_eff_attention - - padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 - causal = attn_mask_type is not None and attn_mask_type.value > 1 - - batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] - # unpad - seq_len_info_q = None - seq_len_info_kv = None - if padded: - # bert style, unpad process - assert ( - attn_mask is not None - ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." - assert attn_mask.dim() == 2, ( - "attention mask is supposed to have shape (batch_size, seq_len), " - + f"but got {attn_mask.dim()} dimensions." - ) - - # bert style - if tgt_len == src_len: - seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) - if batch_size > 1: - query, key, value = self.unpad( - torch.stack([query, key, value], dim=2), seq_len_info_q.indices - ).unbind(dim=1) - else: - query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) - seq_len_info_kv = seq_len_info_q - else: - seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device) - seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) - if batch_size > 1: - query = rearrange(query, "b s ... -> c (b s) ...", c=1) - key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind( - dim=1 - ) - else: - query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) - - out = attn( - query, - key, - value, - seq_len_info_q, - seq_len_info_kv, - dropout_p=self.dropout, - scale=self.scale, - causal=causal, - padded=padded, - ) - - # repad - if padded: - if batch_size > 1: - out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len) - out = rearrange(out, "(b s) h d -> b s h d", b=batch_size) - - out = rearrange(out, "b s h d -> b s (h d)") - return out diff --git a/colossalai/kernel/cuda_native/mha/utils.py b/colossalai/kernel/cuda_native/mha/utils.py deleted file mode 100644 index 5f01e3ef3..000000000 --- a/colossalai/kernel/cuda_native/mha/utils.py +++ /dev/null @@ -1,82 +0,0 @@ -from dataclasses import dataclass -from typing import Iterable, Tuple - -import torch -import torch.nn.functional as F -from einops import rearrange - -from colossalai.utils.device import get_current_device - - -class Unpad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): - ctx.save_for_backward(indices) - # [b, s, ...] - assert tensor.ndim >= 3 - ctx.bsz = tensor.shape[0] - out = rearrange(tensor, "b s ... -> (b s) ...") - ctx.shape = out.shape - # [ntokens, ...] - return out[indices] - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # [ntokens, ...] - grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) - grad[indices] = grad_output - grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz) - # [b, s, ...] - return grad, None - - -class Repad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): - ctx.save_for_backward(indices) - # [ntokens, ...] - tensor = tensor - out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) - # [b*s, ...] - out[indices] = tensor - return out - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # [b*s, ...] - grad = grad_output[indices] - # [ntokens, ...] - return grad, None, None, None - - -@dataclass -class SeqLenInfo: - seqlens: Iterable[int] = None - indices: torch.Tensor = None - max_seqlen: int = None - cu_seqlens: torch.Tensor = None - - @staticmethod - def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()): - if attn_mask is not None: - indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device) - seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten() - else: - batch_size, tgt_len = size[0], size[1] - indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device) - seqlens = torch.LongTensor([tgt_len] * batch_size, device=device) - max_seqlen = max(seqlens) - cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device) - return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens) diff --git a/colossalai/kernel/cuda_native/multihead_attention.py b/colossalai/kernel/cuda_native/multihead_attention.py deleted file mode 100644 index 87afc1862..000000000 --- a/colossalai/kernel/cuda_native/multihead_attention.py +++ /dev/null @@ -1,338 +0,0 @@ -import math -from dataclasses import dataclass - -import torch -from torch import nn -from torch.autograd import Function - - -def check_config(config): - if config.hidden_size % config.nhead != 0: - raise Exception("hidden_size % nhead != 0") - - factor = 8 if config.fp16 else 4 - upbound = factor * 1024 * 4 - if config.hidden_size > upbound: - # as required by ln backward kernel currently - raise Exception(f"hidden_size > {upbound}") - - head_dim = config.hidden_size // config.nhead - if head_dim % factor != 0: - # as required by reshape kernel - raise Exception(f"head_dim({head_dim}) % {factor} != 0") - - -def calc_offset(sizes): - offsets = [0] - tmp = 0 - for x in sizes: - tmp += x - offsets.append(tmp) - return offsets - - -colossal_multihead_attention = None - - -@dataclass -class Config: - max_batch_tokens: int # max batch token numbers - max_seq_len: int # max sequence length - hidden_size: int # size of transformer hidden layers - nhead: int # number of heads in attention - attn_prob_dropout_ratio: float # attention score dropout ratio - hidden_dropout_ratio: float # dropout ration before residual - norm_first: bool # norm_first - fp16: bool # fp16 precision - - -class MultiHeadAttention1DFunc(Function): - @staticmethod - def forward( - ctx, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - config, - ): - cuda_module = colossal_multihead_attention - forward_func = ( - cuda_module.multihead_attention_fw_fp16 if config.fp16 else cuda_module.multihead_attention_fw_fp32 - ) - if config.fp16: - input = input.to(torch.half) - input_mask = input_mask.to(torch.half) - - (output,) = forward_func( - config.layer_id, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - config.training, - config.norm_first, - ) - - if config.is_grad_enabled and config.training: - ctx.save_for_backward( - output, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - ) - ctx.config = config - return output - - @staticmethod - def backward(ctx, grad_output): - assert ctx.config.training - - cuda_module = colossal_multihead_attention - backward_func = ( - cuda_module.multihead_attention_bw_fp16 if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32 - ) - - ( - output, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - ) = ctx.saved_tensors - - grad_input = None - grad_in_proj_weight = None - grad_in_proj_bias = None - grad_out_proj_weight = None - grad_out_proj_bias = None - grad_norm_weight = None - grad_norm_bias = None - - if ctx.config.fp16: - grad_output = grad_output.to(torch.half) - output = output.to(torch.half) - input = input.to(torch.half) - input_mask = input_mask.to(torch.half) - ( - grad_input, - grad_in_proj_weight, - grad_in_proj_bias, - grad_out_proj_weight, - grad_out_proj_bias, - grad_norm_weight, - grad_norm_bias, - ) = backward_func( - ctx.config.layer_id, - grad_output, - output, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - ) - - return ( - grad_input, - None, - grad_in_proj_weight, - grad_in_proj_bias, - grad_out_proj_weight, - grad_out_proj_bias, - grad_norm_weight, - grad_norm_bias, - None, - ) - - -class MultiHeadAttention(nn.Module): - """Initialize the MultiHeadAttention. - - Static variable: - - layer_id: The layer-index counter starting from 0 and incrementing by 1 every time a layer object is instantiated, - e.g. if a model has 24 transformer layers, layer_id goes from 0 to 23. - - Arguments: - hidden_size: Total dimension of hidden_size. - nhead: Number of parallel attention heads. - batch_size: Batch Size for one forward - max_seq_len: Max length of input sequence - dropout: Dropout probability - norm_first: perform LayerNorms before attention - """ - - layer_id = 0 - - def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None): - super(MultiHeadAttention, self).__init__() - - self.config = Config( - batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, fp16 - ) - check_config(self.config) - self.pg = pg - self.pg_size = 1 - if self.pg: - self.pg_size = pg.size() - self.config.layer_id = MultiHeadAttention.layer_id - MultiHeadAttention.layer_id = MultiHeadAttention.layer_id + 1 - - # Load cuda modules if needed - global colossal_multihead_attention - if colossal_multihead_attention is None: - from colossalai.kernel.op_builder import MultiHeadAttnBuilder - - multihead_attention = MultiHeadAttnBuilder().load() - colossal_multihead_attention = multihead_attention - - # create the layer in cuda kernels. - cuda_module = colossal_multihead_attention - create_layer_func = ( - cuda_module.create_multihead_attention_fp16 - if self.config.fp16 - else cuda_module.create_multihead_attention_fp32 - ) - - create_layer_func( - self.config.layer_id, - self.config.max_batch_tokens, - self.config.max_seq_len, - self.config.hidden_size, - self.config.nhead, - self.config.attn_prob_dropout_ratio, - self.config.hidden_dropout_ratio, - self.config.norm_first, - self.pg, - ) - - hs = self.config.hidden_size - - self.precision = torch.float32 - if self.config.fp16: - self.precision = torch.half - - self.hs_per_rank = int(hs / self.pg_size) - - self.in_proj_weight = nn.Parameter(torch.Tensor(3, self.hs_per_rank, hs)) - self.in_proj_bias = nn.Parameter(torch.Tensor(3, self.hs_per_rank)) - self.out_proj_weight = nn.Parameter(torch.Tensor(hs, self.hs_per_rank)) - self.out_proj_bias = nn.Parameter(torch.Tensor(hs)) - self.norm_weight = nn.Parameter(torch.Tensor(hs)) - self.norm_bias = nn.Parameter(torch.Tensor(hs)) - - self.reset_parameters() - torch.cuda.empty_cache() - - def calc_bound(self, w): - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(w) - bound = 1.0 / math.sqrt(fan_in) - return bound - - def reset_parameters(self): - hs = self.config.hidden_size - - nn.init.zeros_(self.out_proj_bias) - - nn.init.ones_(self.norm_weight) - nn.init.zeros_(self.norm_bias) - - if self.pg_size > 1: - rank_in_pg = torch.distributed.get_rank(self.pg) - attn_qkvw_global = torch.empty(hs * 3, hs) - attn_qkvb_global = torch.empty(hs * 3) - nn.init.xavier_uniform_(attn_qkvw_global, 1.0 / math.sqrt(2.0)) - bound = self.calc_bound(attn_qkvw_global) - nn.init.uniform_(attn_qkvb_global, -bound, bound) - - attn_qkvw_global = attn_qkvw_global.cuda() - attn_qkvb_global = attn_qkvb_global.cuda() - torch.distributed.broadcast(attn_qkvw_global, src=0, group=self.pg) - torch.distributed.broadcast(attn_qkvb_global, src=0, group=self.pg) - attn_qkvw_global = attn_qkvw_global.cpu() - attn_qkvb_global = attn_qkvb_global.cpu() - - with torch.no_grad(): - self.in_proj_weight.copy_( - attn_qkvw_global.view(3, hs, hs)[ - :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size), : - ] - ) - self.in_proj_bias.copy_( - attn_qkvb_global.view(3, hs)[ - :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size) - ] - ) - - attn_ow_global = torch.empty(hs, hs) - nn.init.xavier_uniform_(attn_ow_global, 1.0) - attn_ow_global = attn_ow_global.cuda() - torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg) - attn_ow_global = attn_ow_global.cpu() - with torch.no_grad(): - self.out_proj_weight.copy_( - attn_ow_global[:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)] - ) - - else: - attn_qkvw = self.in_proj_weight.view(-1, hs) - nn.init.xavier_uniform_(attn_qkvw, 1.0 / math.sqrt(2.0)) - bound = self.calc_bound(attn_qkvw) - nn.init.uniform_(self.in_proj_bias, -bound, bound) - - nn.init.xavier_uniform_(self.out_proj_weight, 1.0) - - def state_dict(self, destination=None, prefix="", keep_vars=False): - destination = torch.nn.Module.state_dict(self, destination=destination, prefix=prefix, keep_vars=keep_vars) - return destination - - def forward(self, hidden_states, encoder_padding_mask): - self.config.training = self.training - self.config.is_grad_enabled = torch.is_grad_enabled() - hidden_states = hidden_states.contiguous() - encoder_padding_mask = (encoder_padding_mask * -1e8).type_as(hidden_states).contiguous() - - bs, sl, dim = hidden_states.size() - if bs * sl > self.config.max_batch_tokens: - raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.") - if sl > self.config.max_seq_len: - raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.") - if len(encoder_padding_mask.size()) == 1: - assert bs == 1 and sl == encoder_padding_mask.size(0) - else: - assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1) - - output = MultiHeadAttention1DFunc.apply( - hidden_states, - encoder_padding_mask, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj_weight, - self.out_proj_bias, - self.norm_weight, - self.norm_bias, - self.config, - ) - - return output.to(self.precision) diff --git a/colossalai/kernel/extensions b/colossalai/kernel/extensions new file mode 120000 index 000000000..e8eb45a54 --- /dev/null +++ b/colossalai/kernel/extensions @@ -0,0 +1 @@ +../../extensions \ No newline at end of file diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py index 8bebad894..d392649a6 100644 --- a/colossalai/kernel/jit/option.py +++ b/colossalai/kernel/jit/option.py @@ -1,7 +1,7 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear -from colossalai.utils import get_current_device from .bias_dropout_add import bias_dropout_add_fused_train from .bias_gelu import bias_gelu_impl @@ -46,11 +46,13 @@ def warmup_jit_fusion( ): """Compile JIT functions before the main training steps""" - embed = Embedding(vocab_size, hidden_size).to(get_current_device()) - linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device()) - linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device()) + embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device()) + linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device()) + linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_accelerator().get_current_device()) - x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device()) + x = torch.randint( + vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_accelerator().get_current_device() + ) x = embed(x) y, y_bias = linear_1(x) z, z_bias = linear_2(y) @@ -58,8 +60,8 @@ def warmup_jit_fusion( # prop and recomputation for bias_grad, input_grad in zip([True, True], [False, True]): for _ in range(10): - bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device()) - input_ = torch.rand_like(y, dtype=dtype, device=get_current_device()) + bias = torch.rand_like(y_bias, dtype=dtype, device=get_accelerator().get_current_device()) + input_ = torch.rand_like(y, dtype=dtype, device=get_accelerator().get_current_device()) bias.requires_grad, input_.requires_grad = bias_grad, input_grad bias_gelu_impl(input_, bias) @@ -69,9 +71,9 @@ def warmup_jit_fusion( # prop and recomputation for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]): for _ in range(10): - input_ = torch.rand_like(z, dtype=dtype, device=get_current_device()) - residual = torch.rand_like(x, dtype=dtype, device=get_current_device()) - bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device()) + input_ = torch.rand_like(z, dtype=dtype, device=get_accelerator().get_current_device()) + residual = torch.rand_like(x, dtype=dtype, device=get_accelerator().get_current_device()) + bias = torch.rand_like(z_bias, dtype=dtype, device=get_accelerator().get_current_device()) input_.requires_grad = input_grad bias.requires_grad = bias_grad residual.requires_grad = residual_grad diff --git a/colossalai/kernel/kernel_loader.py b/colossalai/kernel/kernel_loader.py new file mode 100644 index 000000000..148c3e3fc --- /dev/null +++ b/colossalai/kernel/kernel_loader.py @@ -0,0 +1,109 @@ +import warnings +from typing import List + +from .extensions import ( + CpuAdamArmExtension, + CpuAdamX86Extension, + FlashAttentionDaoCudaExtension, + FlashAttentionNpuExtension, + FlashAttentionXformersCudaExtension, + FusedOptimizerCudaExtension, + LayerNormCudaExtension, + MoeCudaExtension, + ScaledMaskedSoftmaxCudaExtension, + ScaledUpperTriangleMaskedSoftmaxCudaExtension, +) +from .extensions.base_extension import _Extension + +__all__ = [ + "KernelLoader", + "CPUAdamLoader", + "LayerNormLoader", + "MoeLoader", + "FusedOptimizerLoader", + "ScaledMaskedSoftmaxLoader", + "ScaledUpperTriangleMaskedSoftmaxLoader", +] + + +class KernelLoader: + """ + An abstract class which offers encapsulation to the kernel loading process. + + Usage: + kernel_loader = KernelLoader() + kernel = kernel_loader.load() + """ + + REGISTRY: List[_Extension] = [] + + @classmethod + def register_extension(cls, extension: _Extension): + """ + This classmethod is an extension point which allows users to register their customized + kernel implementations to the loader. + + Args: + extension (_Extension): the extension to be registered. + """ + cls.REGISTRY.append(extension) + + def load(self, ext_name: str = None): + """ + Load the kernel according to the current machine. + + Args: + ext_name (str): the name of the extension to be loaded. If not specified, the loader + will try to look for an kernel available on the current machine. + """ + exts = [ext_cls() for ext_cls in self.__class__.REGISTRY] + + # look for exts which can be built/loaded on the current machine + + if ext_name: + usable_exts = list(filter(lambda ext: ext.name == ext_name, exts)) + else: + usable_exts = [] + for ext in exts: + if ext.is_hardware_available(): + # make sure the machine is compatible during kernel loading + ext.assert_hardware_compatible() + usable_exts.append(ext) + + assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine." + + if len(usable_exts) > 1: + # if more than one usable kernel is found, we will try to load the kernel with the highest priority + usable_exts = sorted(usable_exts, key=lambda ext: ext.priority, reverse=True) + warnings.warn( + f"More than one kernel is available, loading the kernel with the highest priority - {usable_exts[0].__class__.__name__}" + ) + return usable_exts[0].load() + + +class CPUAdamLoader(KernelLoader): + REGISTRY = [CpuAdamX86Extension, CpuAdamArmExtension] + + +class LayerNormLoader(KernelLoader): + REGISTRY = [LayerNormCudaExtension] + + +class MoeLoader(KernelLoader): + REGISTRY = [MoeCudaExtension] + + +class FusedOptimizerLoader(KernelLoader): + REGISTRY = [FusedOptimizerCudaExtension] + + +class ScaledMaskedSoftmaxLoader(KernelLoader): + REGISTRY = [ScaledMaskedSoftmaxCudaExtension] + + +class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader): + REGISTRY = [ScaledUpperTriangleMaskedSoftmaxCudaExtension] + + +class FlashAttentionLoader(KernelLoader): + REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension] diff --git a/colossalai/kernel/op_builder b/colossalai/kernel/op_builder deleted file mode 120000 index db4f9c335..000000000 --- a/colossalai/kernel/op_builder +++ /dev/null @@ -1 +0,0 @@ -../../op_builder \ No newline at end of file diff --git a/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py b/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py index 97ec57fbd..d2dceb50b 100644 --- a/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py +++ b/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py @@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup from torch.optim import Optimizer from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler -from colossalai.kernel.op_builder import FusedOptimBuilder +from colossalai.kernel.kernel_loader import FusedOptimizerLoader from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes @@ -28,7 +28,7 @@ def load_fused_optim(): global fused_optim if fused_optim is None: - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None): diff --git a/colossalai/legacy/amp/torch_amp/torch_amp.py b/colossalai/legacy/amp/torch_amp/torch_amp.py index 0a8d09be2..08f867eee 100644 --- a/colossalai/legacy/amp/torch_amp/torch_amp.py +++ b/colossalai/legacy/amp/torch_amp/torch_amp.py @@ -1,18 +1,19 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from colossalai.utils.device import autocast - import torch.nn as nn from torch import Tensor from torch.nn.modules.loss import _Loss from torch.optim import Optimizer +from colossalai.accelerator import get_accelerator from colossalai.interface import OptimizerWrapper from colossalai.legacy.utils import clip_grad_norm_fp32 from ._grad_scaler import GradScaler +autocast = get_accelerator().autocast + class TorchAMPOptimizer(OptimizerWrapper): """A wrapper class which integrate Pytorch AMP with an optimizer diff --git a/colossalai/legacy/communication/p2p.py b/colossalai/legacy/communication/p2p.py index 19c3919b6..cf0bd4ba2 100644 --- a/colossalai/legacy/communication/p2p.py +++ b/colossalai/legacy/communication/p2p.py @@ -8,9 +8,9 @@ from typing import List, Tuple, Union import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks @@ -43,12 +43,16 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors): if isinstance(recv_shapes, torch.Size): recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors) - buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype) + buffer_recv = torch.empty( + recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype + ) return buffer_recv, recv_split buffer_recv = [] for recv_shape in recv_shapes: recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors) - tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype) + tensor_recv = torch.empty( + recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype + ) buffer_recv.append(tensor_recv) return buffer_recv, recv_split diff --git a/colossalai/legacy/communication/ring.py b/colossalai/legacy/communication/ring.py index a61dae56c..792a15abd 100644 --- a/colossalai/legacy/communication/ring.py +++ b/colossalai/legacy/communication/ring.py @@ -3,9 +3,9 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device, synchronize def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor: @@ -29,7 +29,7 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> current_rank = gpc.get_global_rank() tensor_recv_prev = torch.empty( - buffer_shape, requires_grad=True, device=get_current_device(), dtype=tensor_send_next.dtype + buffer_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=tensor_send_next.dtype ) # send to next rank @@ -52,6 +52,6 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> req.wait() # To protect against race condition when using batch_isend_irecv(). - synchronize() + get_accelerator().synchronize() return tensor_recv_prev diff --git a/colossalai/legacy/communication/utils.py b/colossalai/legacy/communication/utils.py index 6d77f3753..0b7c0eb74 100644 --- a/colossalai/legacy/communication/utils.py +++ b/colossalai/legacy/communication/utils.py @@ -3,9 +3,9 @@ from typing import List, Tuple, Union import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device TensorShape = Union[torch.Size, List[int], Tuple[int]] @@ -35,7 +35,7 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool: if next_rank is None: next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) - tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} + tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()} if isinstance(obj, torch.Tensor): send_obj_nums = torch.tensor(1, **tensor_kwargs) dist.send(send_obj_nums, next_rank) @@ -74,7 +74,7 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size: if prev_rank is None: prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) - tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} + tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()} recv_obj_nums = torch.empty((), **tensor_kwargs) dist.recv(recv_obj_nums, prev_rank) if recv_obj_nums.item() == 1: diff --git a/colossalai/legacy/engine/schedule/_base_schedule.py b/colossalai/legacy/engine/schedule/_base_schedule.py index 4a3ccfda1..9b2913442 100644 --- a/colossalai/legacy/engine/schedule/_base_schedule.py +++ b/colossalai/legacy/engine/schedule/_base_schedule.py @@ -6,8 +6,8 @@ from typing import Callable, Iterable import torch +from colossalai.accelerator import get_accelerator from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device class BaseSchedule(ABC): @@ -29,12 +29,12 @@ class BaseSchedule(ABC): def _move_tensor(element): if torch.is_tensor(element): if not element.is_cuda: - return element.to(get_current_device()).detach() + return element.to(get_accelerator().get_current_device()).detach() return element def _move_to_device(self, data): if isinstance(data, torch.Tensor): - data = data.to(get_current_device()) + data = data.to(get_accelerator().get_current_device()) elif isinstance(data, (list, tuple)): data_to_return = [] for element in data: diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_pipeline_schedule.py index 5fd5602e7..4a23853c1 100644 --- a/colossalai/legacy/engine/schedule/_pipeline_schedule.py +++ b/colossalai/legacy/engine/schedule/_pipeline_schedule.py @@ -7,12 +7,12 @@ from typing import Callable, List, Tuple, Union import torch.cuda import colossalai.legacy.communication as comm +from colossalai.accelerator import get_accelerator from colossalai.legacy.amp.naive_amp import NaiveAMPModel from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank from colossalai.logging import get_dist_logger -from colossalai.utils.device import get_current_device from ._base_schedule import BaseSchedule @@ -352,7 +352,7 @@ class PipelineSchedule(BaseSchedule): output_objs = [] return_tensors = [] if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) else: accum_loss = None # Used for tensor meta information communication @@ -584,7 +584,7 @@ class InterleavedPipelineSchedule(PipelineSchedule): if not forward_only: output_obj_grads = [[] for _ in range(len(model))] if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) else: accum_loss = None diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py index 4cd7e47c3..6e7760218 100644 --- a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py +++ b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py @@ -6,10 +6,10 @@ from typing import Iterable, Tuple import torch.cuda import colossalai.legacy.communication.p2p_v2 as comm +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.engine import Engine -from colossalai.utils.device import get_current_device from ._pipeline_schedule import PipelineSchedule @@ -99,7 +99,7 @@ class PipelineScheduleV2(PipelineSchedule): output_objs = [] return_tensors = [] if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) else: accum_loss = None diff --git a/colossalai/legacy/initialize.py b/colossalai/legacy/initialize.py index 4035bd6b5..d99a7d3f0 100644 --- a/colossalai/legacy/initialize.py +++ b/colossalai/legacy/initialize.py @@ -15,6 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader +from colossalai.accelerator import get_accelerator from colossalai.context import Config, ConfigException from colossalai.interface import OptimizerWrapper from colossalai.legacy.amp import AMP_TYPE, convert_to_amp @@ -34,7 +35,6 @@ from colossalai.legacy.utils import is_using_ddp, is_using_pp, is_using_sequence from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2 from colossalai.legacy.zero.gemini.ophooks import BaseOpHook from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device def get_default_parser(): @@ -309,9 +309,9 @@ def initialize( else: if isinstance(model, nn.Module): # first sync model across dp ranks - model.to(get_current_device()) + model.to(get_accelerator().get_current_device()) elif isinstance(model, Callable): - model = model().to(get_current_device()) + model = model().to(get_accelerator().get_current_device()) # optimizer maybe a optimizer_cls if isinstance(optimizer, Callable): diff --git a/colossalai/legacy/nn/layer/colossalai_layer/embedding.py b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py index e1db0fe98..aa661664f 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/embedding.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py @@ -3,8 +3,8 @@ from typing import Callable from torch import dtype, nn +from colossalai.accelerator import get_accelerator from colossalai.nn import init -from colossalai.utils import get_current_device from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D @@ -83,7 +83,7 @@ class Embedding(ColossalaiModule): embed = ( nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, **kwargs) .to(dtype) - .to(get_current_device()) + .to(get_accelerator().get_current_device()) ) weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim) elif num_embeddings <= vocab_parallel_limit: diff --git a/colossalai/legacy/nn/layer/colossalai_layer/normalization.py b/colossalai/legacy/nn/layer/colossalai_layer/normalization.py index f8e317e72..58842f481 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/normalization.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/normalization.py @@ -1,6 +1,6 @@ from torch import nn -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from ..parallel_1d import LayerNorm1D from ..parallel_2d import LayerNorm2D @@ -36,7 +36,7 @@ class LayerNorm(ColossalaiModule): def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None: tensor_parallel = get_tensor_parallel_mode() if tensor_parallel is None: - norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device()) + norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_accelerator().get_current_device()) else: norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype) super().__init__(norm) diff --git a/colossalai/legacy/nn/layer/parallel_1d/layers.py b/colossalai/legacy/nn/layer/parallel_1d/layers.py index b6ec5347f..b38e1c433 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_1d/layers.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from torch import Tensor from torch.nn.parameter import Parameter -from colossalai.kernel import LayerNorm +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import broadcast from colossalai.legacy.context import ParallelMode, seed from colossalai.legacy.context.parallel_context import global_context as gpc @@ -22,7 +22,7 @@ from colossalai.legacy.utils.checkpointing import ( partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.device import get_current_device +from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm from ..base_layer import ParallelLayer from ..colossalai_layer._utils import ColossalaiModule @@ -221,7 +221,7 @@ class Classifier1D(ParallelLayer): # Parameters. # Initialize weight. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False @@ -357,7 +357,7 @@ class VocabParallelClassifier1D(ParallelLayer): # Parameters. # Initialize weight. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False @@ -499,7 +499,7 @@ class Linear1D_Col(ParallelLayer): # Parameters. # Initialize weight. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) if bias: @@ -638,7 +638,7 @@ class Linear1D_Row(ParallelLayer): # Parameters. # Initialize weight. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) if self.stream_chunk_num > 1: @@ -802,7 +802,9 @@ class Embedding1D(ParallelLayer): self.embed_kwargs = kwargs self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + torch.empty( + (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer) @@ -912,7 +914,11 @@ class VocabParallelEmbedding1D(ParallelLayer): self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype) + torch.empty( + (self.num_embeddings_per_partition, self.embed_dim), + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.reset_parameters(weight_initializer) diff --git a/colossalai/legacy/nn/layer/parallel_2d/_operation.py b/colossalai/legacy/nn/layer/parallel_2d/_operation.py index f1eff7128..f67ee2e60 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_2d/_operation.py @@ -5,10 +5,10 @@ import torch.distributed as dist from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device def matmul_2d( @@ -250,7 +250,7 @@ class Matmul_AB_2D(torch.autograd.Function): B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[0], B.shape[-1]) - C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -399,7 +399,7 @@ class Matmul_ABT_2D(torch.autograd.Function): B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[0], B.shape[0]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -556,7 +556,7 @@ class Matmul_ATB_2D(torch.autograd.Function): B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[-1], B.shape[-1]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases diff --git a/colossalai/legacy/nn/layer/parallel_2d/layers.py b/colossalai/legacy/nn/layer/parallel_2d/layers.py index f81c5334a..4987afa18 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_2d/layers.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import broadcast from colossalai.legacy.context import ParallelMode, seed from colossalai.legacy.core import global_context as gpc @@ -18,7 +19,6 @@ from colossalai.legacy.utils.checkpointing import ( partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.device import get_current_device from ..base_layer import ParallelLayer from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple @@ -82,7 +82,7 @@ class Linear2D(ParallelLayer): self.hidden_size_per_partition = divide(self.out_features, self.summa_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter( torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs) ) @@ -259,7 +259,7 @@ class LayerNorm2D(ParallelLayer): self.partitioned_partition = divide(normalized_shape, self.summa_dim**2) # create parameters - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) if bias: @@ -438,18 +438,24 @@ class PatchEmbedding2D(ParallelLayer): self.weight = Parameter( torch.empty( (self.embed_size_per_partition, in_chans, *self.patch_size), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) - self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype) + ) self.cls_token = Parameter( - torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype) + torch.zeros( + (1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.pos_embed = Parameter( torch.zeros( - (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype + (1, self.num_patches + 1, self.embed_size_per_partition), + device=get_accelerator().get_current_device(), + dtype=dtype, ) ) @@ -619,7 +625,9 @@ class Embedding2D(ParallelLayer): self.embed_kwargs = kwargs self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + torch.empty( + (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer) @@ -758,7 +766,7 @@ class VocabParallelEmbedding2D(ParallelLayer): self.weight = Parameter( torch.empty( (self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) @@ -895,11 +903,18 @@ class Classifier2D(ParallelLayer): self.has_weight = False else: self.weight = Parameter( - torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype) + torch.empty( + self.num_classes, + self.input_size_per_partition, + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.has_weight = True if bias: - self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -1052,7 +1067,7 @@ class VocabParallelClassifier2D(ParallelLayer): self.output_size_per_partition = divide(num_classes, self.summa_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py index 50900c135..43328bd03 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py @@ -5,10 +5,10 @@ import torch.distributed as dist from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device def get_parallel_group(parallel_mode: ParallelMode): @@ -205,7 +205,7 @@ class Matmul_AB_2p5D(torch.autograd.Function): B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[0], B.shape[-1]) - C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -362,7 +362,7 @@ class Matmul_ABT_2p5D(torch.autograd.Function): B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[0], B.shape[0]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -527,7 +527,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function): B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[-1], B.shape[-1]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -661,7 +661,9 @@ class _Add_Bias_2p5D(torch.autograd.Function): if row_rank == 0: bias_temp = bias.clone() else: - bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device()) + bias_temp = torch.zeros( + output_size_per_partition, dtype=bias.dtype, device=get_accelerator().get_current_device() + ) src_rank = ( col_rank + dep_rank * tesseract_dim**2 @@ -984,7 +986,7 @@ class SplitFirst(torch.autograd.Function): @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: grad_shape = (ctx.batch_size,) + output_grad.shape[1:] - grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device()) + grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_accelerator().get_current_device()) dist.all_gather( list(grad.chunk(ctx.tesseract_dim, dim=0)), output_grad.contiguous(), group=gpc.get_group(ctx.para_mode) ) diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py index b451a4031..d9410f1cb 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import broadcast from colossalai.legacy.context import ParallelMode, seed from colossalai.legacy.core import global_context as gpc @@ -19,7 +20,6 @@ from colossalai.legacy.utils.checkpointing import ( partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.device import get_current_device from ..base_layer import ParallelLayer from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple @@ -84,7 +84,7 @@ class Linear2p5D(ParallelLayer): self.hidden_size_per_partition = divide(out_features, self.tesseract_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter( torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs) ) @@ -272,7 +272,7 @@ class LayerNorm2p5D(ParallelLayer): self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # * # create parameters - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) if bias: @@ -451,18 +451,24 @@ class PatchEmbedding2p5D(ParallelLayer): self.weight = Parameter( torch.empty( (self.embed_size_per_partition, in_chans, *self.patch_size), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) - self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype) + ) self.cls_token = Parameter( - torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype) + torch.zeros( + (1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.pos_embed = Parameter( torch.zeros( - (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype + (1, self.num_patches + 1, self.embed_size_per_partition), + device=get_accelerator().get_current_device(), + dtype=dtype, ) ) @@ -632,7 +638,9 @@ class Embedding2p5D(ParallelLayer): self.embed_kwargs = kwargs self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + torch.empty( + (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer) @@ -772,7 +780,7 @@ class VocabParallelEmbedding2p5D(ParallelLayer): self.weight = Parameter( torch.empty( (self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) @@ -910,11 +918,18 @@ class Classifier2p5D(ParallelLayer): self.has_weight = False else: self.weight = Parameter( - torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype) + torch.empty( + self.num_classes, + self.input_size_per_partition, + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.has_weight = True if bias: - self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -1068,7 +1083,7 @@ class VocabParallelClassifier2p5D(ParallelLayer): self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False diff --git a/colossalai/legacy/nn/layer/parallel_3d/layers.py b/colossalai/legacy/nn/layer/parallel_3d/layers.py index 16e515f87..bb01ec851 100644 --- a/colossalai/legacy/nn/layer/parallel_3d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_3d/layers.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import all_reduce, broadcast from colossalai.legacy.constants import ( INPUT_GROUP_3D, @@ -27,7 +28,6 @@ from colossalai.legacy.utils.checkpointing import ( partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.device import get_current_device from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple from ._operation import ( @@ -69,11 +69,13 @@ class LayerNorm3D(ParallelLayer): self.normalized_shape_per_partition = divide(normalized_shape, self.depth) self.weight = Parameter( - torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype) + torch.ones(self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype) ) if bias: self.bias = Parameter( - torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype) + torch.zeros( + self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype + ) ) else: self.bias = None @@ -202,13 +204,15 @@ class Linear3D(ParallelLayer): torch.empty( self.in_features_per_partition, self.out_features_per_partition, - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) if bias: self.bias = Parameter( - torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype) + torch.zeros( + self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype + ) ) else: self.bias = None @@ -380,11 +384,18 @@ class Classifier3D(ParallelLayer): self.has_weight = False else: self.weight = Parameter( - torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype) + torch.empty( + self.num_classes, + self.in_features_per_partition, + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.has_weight = True if bias: - self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -523,14 +534,16 @@ class VocabParallelClassifier3D(ParallelLayer): torch.empty( self.out_features_per_partition, self.in_features_per_partition, - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) self.has_weight = True if bias: self.bias = Parameter( - torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype) + torch.zeros( + self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype + ) ) else: self.bias = None @@ -705,16 +718,24 @@ class PatchEmbedding3D(ParallelLayer): self.weight = nn.Parameter( torch.empty( - (embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype + (embed_size_per_partition, in_chans, *self.patch_size), + device=get_accelerator().get_current_device(), + dtype=dtype, ) ) - self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype)) + self.bias = nn.Parameter( + torch.empty(embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype) + ) self.cls_token = nn.Parameter( - torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype) + torch.zeros((1, 1, embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype) ) self.pos_embed = nn.Parameter( - torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype) + torch.zeros( + (1, self.num_patches + 1, embed_size_per_partition), + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) @@ -880,7 +901,9 @@ class Embedding3D(ParallelLayer): self.embed_kwargs = kwargs self.weight = nn.Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + torch.empty( + (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer) @@ -1019,7 +1042,7 @@ class VocabParallelEmbedding3D(ParallelLayer): self.weight = Parameter( torch.empty( (self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) diff --git a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py index 24d5499e3..4e9bf364d 100644 --- a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py @@ -5,11 +5,11 @@ import torch from torch import distributed as dist from torch.cuda.amp import custom_bwd, custom_fwd +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import ring_forward from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_sequence._utils import _calc_current_device_range, _calc_incoming_device_range -from colossalai.utils import get_current_device class RingQK(torch.autograd.Function): @@ -30,7 +30,7 @@ class RingQK(torch.autograd.Function): sub_seq_length, sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE), dtype=sub_q.dtype, - device=get_current_device(), + device=get_accelerator().get_current_device(), ) # compute local QK^T @@ -71,7 +71,7 @@ class RingQK(torch.autograd.Function): grad_q = torch.zeros_like( sub_q, dtype=sub_q.dtype, - device=get_current_device(), + device=get_accelerator().get_current_device(), ) # compute with local sub_k @@ -105,7 +105,7 @@ class RingAV(torch.autograd.Function): batch_size * num_attention_heads, sub_seq_length, attention_head_size, - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=attention_score.dtype, ) @@ -142,7 +142,9 @@ class RingAV(torch.autograd.Function): grad_v /= local_world_size # calculate gradient for attention score - grad_attention_score = torch.zeros_like(attention_scores, dtype=grad_output.dtype, device=get_current_device()) + grad_attention_score = torch.zeros_like( + attention_scores, dtype=grad_output.dtype, device=get_accelerator().get_current_device() + ) # compute with local sub_k grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1)) diff --git a/colossalai/legacy/nn/layer/parallel_sequence/layers.py b/colossalai/legacy/nn/layer/parallel_sequence/layers.py index 063b0cd8e..445b7e4cd 100644 --- a/colossalai/legacy/nn/layer/parallel_sequence/layers.py +++ b/colossalai/legacy/nn/layer/parallel_sequence/layers.py @@ -8,13 +8,12 @@ import torch.nn as nn import torch.nn.functional as F from torch.nn import Parameter -from colossalai.kernel import FusedScaleMaskSoftmax -from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType from colossalai.legacy.context import seed from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_sequence._operation import RingAV, RingQK from colossalai.legacy.registry import LAYERS +from colossalai.nn.layer.scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax @LAYERS.register_module diff --git a/colossalai/legacy/nn/layer/vanilla/layers.py b/colossalai/legacy/nn/layer/vanilla/layers.py index 590ad5ff6..3a1c2e57b 100644 --- a/colossalai/legacy/nn/layer/vanilla/layers.py +++ b/colossalai/legacy/nn/layer/vanilla/layers.py @@ -7,10 +7,10 @@ from torch import Tensor from torch import nn as nn from torch.nn.parameter import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import seed from colossalai.legacy.registry import LAYERS from colossalai.nn import init as init -from colossalai.utils.device import get_current_device from ..utils import to_2tuple @@ -173,12 +173,18 @@ class VanillaPatchEmbedding(nn.Module): self.flatten = flatten self.weight = nn.Parameter( - torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype) + torch.empty( + (embed_size, in_chans, *self.patch_size), device=get_accelerator().get_current_device(), dtype=dtype + ) + ) + self.bias = nn.Parameter(torch.empty(embed_size, device=get_accelerator().get_current_device(), dtype=dtype)) + self.cls_token = nn.Parameter( + torch.zeros((1, 1, embed_size), device=get_accelerator().get_current_device(), dtype=dtype) ) - self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype)) - self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype)) self.pos_embed = nn.Parameter( - torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype) + torch.zeros( + (1, self.num_patches + 1, embed_size), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) @@ -242,11 +248,15 @@ class VanillaClassifier(nn.Module): self.has_weight = False else: self.weight = nn.Parameter( - torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype) + torch.empty( + self.num_classes, self.in_features, device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.has_weight = True if bias: - self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + self.bias = nn.Parameter( + torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -287,7 +297,7 @@ class VanillaLayerNorm(nn.Module): self.normalized_shape = (normalized_shape,) self.variance_epsilon = eps - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs)) if bias: @@ -333,7 +343,7 @@ class VanillaLinear(nn.Module): self.in_features = in_features self.out_features = out_features self.skip_bias_add = skip_bias_add - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) if bias: self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) diff --git a/colossalai/legacy/nn/loss/loss_2d.py b/colossalai/legacy/nn/loss/loss_2d.py index 44f39a6db..474fd4a2c 100644 --- a/colossalai/legacy/nn/loss/loss_2d.py +++ b/colossalai/legacy/nn/loss/loss_2d.py @@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d from colossalai.legacy.nn.layer.parallel_2d._utils import assert_summa_initialization from colossalai.legacy.registry import LOSSES -from colossalai.utils import get_current_device @LOSSES.register_module @@ -118,7 +118,7 @@ class _VocabParallelCrossEntropy2D(torch.autograd.Function): grad_2d = grad_input.view(-1, partition_vocab_size) # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device()) grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. diff --git a/colossalai/legacy/nn/loss/loss_2p5d.py b/colossalai/legacy/nn/loss/loss_2p5d.py index c57bf26e9..b423ab3d8 100644 --- a/colossalai/legacy/nn/loss/loss_2p5d.py +++ b/colossalai/legacy/nn/loss/loss_2p5d.py @@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d from colossalai.legacy.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization from colossalai.legacy.registry import LOSSES -from colossalai.utils import get_current_device @LOSSES.register_module @@ -112,7 +112,7 @@ class _VocabParallelCrossEntropy2p5D(torch.autograd.Function): grad_2d = grad_input.view(-1, partition_vocab_size) # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device()) grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. diff --git a/colossalai/legacy/nn/loss/loss_3d.py b/colossalai/legacy/nn/loss/loss_3d.py index 988317cae..de6a674d6 100644 --- a/colossalai/legacy/nn/loss/loss_3d.py +++ b/colossalai/legacy/nn/loss/loss_3d.py @@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss +from colossalai.accelerator import get_accelerator from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from colossalai.legacy.registry import LOSSES -from colossalai.utils import get_current_device @LOSSES.register_module @@ -80,7 +80,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function): target_mask = (targets < vocab_start) | (targets > vocab_end) masked_target = targets.clone() - vocab_start masked_target[target_mask] = 0 - arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device()) + arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_accelerator().get_current_device()) predicted_logits = logits[arange_1d, masked_target] predicted_logits = predicted_logits.clone().contiguous().view_as(targets) predicted_logits[target_mask] = 0.0 @@ -110,7 +110,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function): grad_2d = input_grad.view(-1, partition_vocab_size) # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device()) grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() input_grad.mul_(output_grad.unsqueeze(dim=-1)) diff --git a/colossalai/legacy/trainer/hooks/_metric_hook.py b/colossalai/legacy/trainer/hooks/_metric_hook.py index 35a7f0a15..0e6731db5 100644 --- a/colossalai/legacy/trainer/hooks/_metric_hook.py +++ b/colossalai/legacy/trainer/hooks/_metric_hook.py @@ -7,12 +7,12 @@ from typing import Callable import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import all_reduce from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.registry import HOOKS from colossalai.legacy.utils import is_no_pp_or_last_stage -from colossalai.utils import get_current_device from ._base_hook import BaseHook from ._commons_ import _format_number @@ -82,8 +82,8 @@ class LossMetric(Metric): def __init__(self, epoch_only): super().__init__(epoch_only=epoch_only) - self.last_step_loss = torch.zeros(1, device=get_current_device()) - self.accum_loss = torch.zeros(1, device=get_current_device()) + self.last_step_loss = torch.zeros(1, device=get_accelerator().get_current_device()) + self.accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) self.count = 0 def reset(self) -> None: @@ -164,10 +164,10 @@ class AccuracyMetric(Metric): def __init__(self, epoch_only: bool, accuracy_func: Callable): super().__init__(epoch_only=epoch_only) self.acc = accuracy_func - self.last_step_sum = torch.zeros(1, device=get_current_device()) - self.last_step_correct = torch.zeros(1, device=get_current_device()) - self.accumulated_sum = torch.zeros(1, device=get_current_device()) - self.accumulated_correct = torch.zeros(1, device=get_current_device()) + self.last_step_sum = torch.zeros(1, device=get_accelerator().get_current_device()) + self.last_step_correct = torch.zeros(1, device=get_accelerator().get_current_device()) + self.accumulated_sum = torch.zeros(1, device=get_accelerator().get_current_device()) + self.accumulated_correct = torch.zeros(1, device=get_accelerator().get_current_device()) def reset(self) -> None: self.last_step_sum.zero_() @@ -320,10 +320,10 @@ class ThroughputMetric(Metric): super().__init__(epoch_only=epoch_only) self.ignored_steps = ignored_steps self.cur_steps = 0 - self.accumulated_num_samples = torch.zeros(1, device=get_current_device()) - self.accumulated_used_time = torch.zeros(1, device=get_current_device()) - self.last_step_num_samples = torch.zeros(1, device=get_current_device()) - self.last_step_used_time = torch.zeros(1, device=get_current_device()) + self.accumulated_num_samples = torch.zeros(1, device=get_accelerator().get_current_device()) + self.accumulated_used_time = torch.zeros(1, device=get_accelerator().get_current_device()) + self.last_step_num_samples = torch.zeros(1, device=get_accelerator().get_current_device()) + self.last_step_used_time = torch.zeros(1, device=get_accelerator().get_current_device()) self._tflop_per_step = tflop_per_step self._use_local = use_local diff --git a/colossalai/legacy/utils/activation_checkpoint.py b/colossalai/legacy/utils/activation_checkpoint.py index 9a8051ae9..d1382cb1e 100644 --- a/colossalai/legacy/utils/activation_checkpoint.py +++ b/colossalai/legacy/utils/activation_checkpoint.py @@ -6,8 +6,8 @@ import weakref import torch from torch.utils.checkpoint import check_backward_validity, detach_variable +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states -from colossalai.utils.device import autocast, get_current_device def copy_to_device(obj, device): @@ -33,7 +33,7 @@ class CheckpointFunction(torch.autograd.Function): check_backward_validity(args) ctx.run_function = run_function ctx.activation_offload = activation_offload - ctx.device = get_current_device() + ctx.device = get_accelerator().get_current_device() # preserve rng states ctx.fwd_cpu_rng_state = torch.get_rng_state() @@ -110,7 +110,7 @@ class CheckpointFunction(torch.autograd.Function): inputs[idx] = tensors[i] detached_inputs = detach_variable(tuple(inputs)) if ctx.had_autocast_in_fwd: - with torch.enable_grad(), autocast(): + with torch.enable_grad(), get_accelerator().autocast()(): outputs = ctx.run_function(*detached_inputs) else: with torch.enable_grad(): @@ -226,7 +226,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args): # rerun forward, the inner_pack will store all the activations in storage if has_autocast_in_fwd: - with torch.enable_grad(), autocast(), torch.autograd.graph.saved_tensors_hooks( + with torch.enable_grad(), get_accelerator().autocast()(), torch.autograd.graph.saved_tensors_hooks( inner_pack, inner_unpack ): _unused = function(*args) @@ -245,7 +245,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args): # get device if we need to offload the activation if activation_offload: - device = get_current_device() + device = get_accelerator().get_current_device() # run function with pack and unpack as saved_tensors_hooks with torch.autograd.graph.saved_tensors_hooks(pack, unpack): diff --git a/colossalai/legacy/utils/common.py b/colossalai/legacy/utils/common.py index 671bcc3d6..76ec08e96 100644 --- a/colossalai/legacy/utils/common.py +++ b/colossalai/legacy/utils/common.py @@ -96,9 +96,9 @@ def _calc_l2_norm(grads): global fused_optim if fused_optim is None: - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() norm = 0.0 if len(grads) > 0: diff --git a/colossalai/legacy/utils/memory.py b/colossalai/legacy/utils/memory.py index 2f99a7d2f..cfb22d315 100644 --- a/colossalai/legacy/utils/memory.py +++ b/colossalai/legacy/utils/memory.py @@ -6,9 +6,9 @@ import torch import torch.distributed as dist from packaging import version +from colossalai.accelerator import get_accelerator from colossalai.legacy.core import global_context as gpc from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device _GLOBAL_CUDA_MEM_FRACTION = 1.0 _GLOBAL_CPU_MEM_CAPACITY = -1 @@ -112,7 +112,10 @@ def colo_device_memory_capacity(device: torch.device) -> int: # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory. return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node if device.type == "cuda": - return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION + return ( + torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory + * _GLOBAL_CUDA_MEM_FRACTION + ) def colo_device_memory_used(device: torch.device) -> int: @@ -153,7 +156,7 @@ def colo_set_process_memory_fraction(ratio: float) -> None: return global _GLOBAL_CUDA_MEM_FRACTION _GLOBAL_CUDA_MEM_FRACTION = ratio - torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device()) + torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_accelerator().get_current_device()) def colo_set_cpu_memory_capacity(size: int) -> None: diff --git a/colossalai/legacy/utils/profiler/legacy/comm_profiler.py b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py index ad54b989f..a9e3ffe1a 100644 --- a/colossalai/legacy/utils/profiler/legacy/comm_profiler.py +++ b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py @@ -8,7 +8,7 @@ import torch.distributed as dist from torch.autograd.profiler import profile from torch.distributed import ReduceOp -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time @@ -177,7 +177,7 @@ class CommProfiler(BaseProfiler): assert current_comm_event is not None, "dist op has not been found" - buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device()) + buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_accelerator().get_current_device()) torch_all_reduce(buffer, op=ReduceOp.MIN, group=group) current_comm_event.self_cuda_time = buffer.item() diff --git a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py index e336717f4..b0360880e 100644 --- a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py +++ b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py @@ -3,7 +3,7 @@ import types from time import time from typing import List -from colossalai.utils.device import get_current_device +from colossalai.accelerator import get_accelerator from .stateful_tensor import StatefulTensor, TensorState from .tensor_placement_policy import TensorPlacementPolicy @@ -69,7 +69,7 @@ class StatefulTensorMgr(object): # move COMPUTE tensors to CUDA self._cpu_gpu_move_volume += cuda_demand for t in move_to_cuda_tensor_list: - colo_model_data_tensor_move_inline(t, get_current_device()) + colo_model_data_tensor_move_inline(t, get_accelerator().get_current_device()) @property def cpu_gpu_move_volume(self): diff --git a/colossalai/legacy/zero/gemini/tensor_placement_policy.py b/colossalai/legacy/zero/gemini/tensor_placement_policy.py index 3aca80cfe..6fde91d4a 100644 --- a/colossalai/legacy/zero/gemini/tensor_placement_policy.py +++ b/colossalai/legacy/zero/gemini/tensor_placement_policy.py @@ -5,8 +5,8 @@ from typing import List, Optional, Type import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.utils.memory import colo_device_memory_capacity -from colossalai.utils import get_current_device from colossalai.zero.gemini.memory_tracer import MemStatsCollector from .stateful_tensor import StatefulTensor @@ -38,7 +38,7 @@ class CPUTensorPlacementPolicy(TensorPlacementPolicy): class CUDATensorPlacementPolicy(TensorPlacementPolicy): def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: assert torch.cuda.is_available(), "Cannot use CUDATensorPlacementPolicy when CUDA is not available" - super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector) + super().__init__(get_accelerator().get_current_device(), mem_stats_collector=mem_stats_collector) def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int: return 0, 0 @@ -78,7 +78,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy): int: the volume of memory that is evicted """ start = time() - cuda_capacity = colo_device_memory_capacity(get_current_device()) + cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device()) used_cuda_model_data = StatefulTensor.GST_MGR.total_mem["cuda"] if warmup: # We designate a part of CUDA memory for model data in warmup iterations. diff --git a/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py index b9d3071a8..e5a35dea1 100644 --- a/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py +++ b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py @@ -4,8 +4,8 @@ import torch import torch.distributed as dist from torch._utils import _flatten_dense_tensors as flatten +from colossalai.accelerator import get_accelerator from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor -from colossalai.utils import get_current_device from .tensor_shard_strategy import TensorShardStrategy @@ -30,9 +30,11 @@ class BucketTensorShardStrategy(TensorShardStrategy): rank = dist.get_rank(process_group) for i in range(world_size): if i == rank: - buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device())) + buffer_list.append( + flatten([t.payload for t in tensor_list]).cuda(get_accelerator().get_current_device()) + ) else: - buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device())) + buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_accelerator().get_current_device())) dist.all_gather(buffer_list, buffer_list[rank], group=process_group) # Move to target device before splitting buffer # Ensure we utilize maximum PCIE bandwidth diff --git a/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py index ebaef774b..fb6ef534b 100644 --- a/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py +++ b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py @@ -3,11 +3,11 @@ from typing import List, Optional import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline from colossalai.legacy.zero.shard_utils import BaseShardStrategy from colossalai.legacy.zero.shard_utils.commons import get_shard from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor -from colossalai.utils import get_current_device class TensorShardStrategy(BaseShardStrategy): @@ -34,9 +34,9 @@ class TensorShardStrategy(BaseShardStrategy): if t.is_sharded: return if t.payload.device.type == "cuda": - assert t.payload.device == get_current_device(), ( + assert t.payload.device == get_accelerator().get_current_device(), ( f"shard tensor on cuda device index {t.payload.device.index}," - f" but current cuda device is {get_current_device()}" + f" but current cuda device is {get_accelerator().get_current_device()}" ) sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group)) t.payload_reset(sharded_payload) @@ -50,7 +50,9 @@ class TensorShardStrategy(BaseShardStrategy): world_size = dist.get_world_size(process_group) rank = dist.get_rank(process_group) - buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device()) + buffer = torch.empty( + payload_numel * world_size, dtype=t.payload.dtype, device=get_accelerator().get_current_device() + ) buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0)) buffer_list[rank].copy_(t.payload) diff --git a/colossalai/legacy/zero/sharded_model/sharded_model_v2.py b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py index 85f2ac215..bb7744a80 100644 --- a/colossalai/legacy/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py @@ -10,6 +10,7 @@ import torch.nn as nn from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.utils.memory import colo_device_memory_capacity @@ -22,7 +23,7 @@ from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_move_to_c from colossalai.legacy.zero.shard_utils import BaseShardStrategy from colossalai.legacy.zero.sharded_model.reduce_scatter import ReduceScatterBucketer from colossalai.logging import get_dist_logger -from colossalai.utils import disposable, get_current_device +from colossalai.utils import disposable from colossalai.zero.gemini.memory_tracer import MemStatsCollector from ._utils import ( @@ -212,8 +213,12 @@ class ShardedModelV2(nn.Module): self.logger.error(f"dump memory tracer collected information to a {filename}", ranks=[0]) if gpc.get_global_rank() == 0: with open(filename, "w+") as f: - f.write(f"cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n") - f.write(f"cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n") + f.write( + f"cuda reserved {torch.cuda.memory_reserved(get_accelerator().get_current_device()) / 1e9} GB\n" + ) + f.write( + f"cuda max allocated {torch.cuda.max_memory_allocated(get_accelerator().get_current_device()) / 1e9} GB\n" + ) f.write("CUDA model data (GB)\n") f.write("\n") f.write("CUDA non model data (GB)\n") @@ -266,7 +271,8 @@ class ShardedModelV2(nn.Module): # model data is fixed in cuda during training. # cuda margin space can be used to store OS. self._cuda_margin_space = ( - colo_device_memory_capacity(get_current_device()) - self._memstats_collector._memstats.max_overall_cuda + colo_device_memory_capacity(get_accelerator().get_current_device()) + - self._memstats_collector._memstats.max_overall_cuda ) @torch.no_grad() diff --git a/colossalai/legacy/zero/sharded_model/zero_hook.py b/colossalai/legacy/zero/sharded_model/zero_hook.py index 892e9f31d..332f44d53 100644 --- a/colossalai/legacy/zero/sharded_model/zero_hook.py +++ b/colossalai/legacy/zero/sharded_model/zero_hook.py @@ -3,13 +3,13 @@ from typing import Optional import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.registry import OPHOOKS from colossalai.legacy.zero.gemini.ophooks import BaseOpHook from colossalai.legacy.zero.gemini.stateful_tensor import TensorState from colossalai.legacy.zero.gemini.stateful_tensor_mgr import StatefulTensorMgr from colossalai.legacy.zero.shard_utils import BaseShardStrategy from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device from colossalai.zero.gemini.memory_tracer import MemStatsCollector @@ -33,7 +33,7 @@ class ZeroHook(BaseOpHook): self.process_group = process_group # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU - self.computing_device = get_current_device() + self.computing_device = get_accelerator().get_current_device() self._memstarts_collector = memstarts_collector self._stateful_tensor_mgr = stateful_tensor_mgr diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index c71e6c1f4..34342436f 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -11,9 +11,9 @@ MOE_KERNEL = None def load_moe(): global MOE_KERNEL - from colossalai.kernel.op_builder import MOEBuilder + from colossalai.kernel.kernel_loader import MoeLoader - MOE_KERNEL = MOEBuilder().load() + MOE_KERNEL = MoeLoader().load() class AllGather(torch.autograd.Function): @@ -145,14 +145,8 @@ class AllToAll(torch.autograd.Function): class HierarchicalAllToAll(torch.autograd.Function): - @staticmethod - def forward( - ctx: Any, - inputs: Tensor, - groups: Tuple[ProcessGroup, ProcessGroup], - src_rank: int - ) -> Tensor: + def forward(ctx: Any, inputs: Tensor, groups: Tuple[ProcessGroup, ProcessGroup], src_rank: int) -> Tensor: """ Returns: outputs: Tensor @@ -276,8 +270,9 @@ class MoeCombine(torch.autograd.Function): if tokens_grad.dtype != torch.float32: tokens_grad = tokens_grad.to(torch.float32) - d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, - mask, dest_idx) + d_expert, d_logits = MOE_KERNEL.combine_backward( + ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, mask, dest_idx + ) if d_expert.dtype != ctx.dtype: d_expert = d_expert.to(ctx.dtype) diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py index 3e64d796c..eaca75b8f 100644 --- a/colossalai/moe/manager.py +++ b/colossalai/moe/manager.py @@ -69,7 +69,7 @@ class MoEManager(metaclass=SingletonMeta): fixed_dp_size (int, optional): Fixed dp size in fixed mode. Defaults to 0. fixed_ep_size (int, optional): Fixed ep size in fixed mode. Defaults to 0. fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0. - use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Defaults to True. + use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True. """ assert not self.is_initialized, "MoE distributed context shouldn't be set up again" assert torch.cuda.is_available(), "MoE requires to enable CUDA first" diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index c5bb50862..f5815d05d 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -8,9 +8,9 @@ import torch.nn as nn import torch.nn.functional as F from torch.distributed import ProcessGroup +from colossalai.accelerator import get_accelerator from colossalai.moe._operation import moe_cumsum from colossalai.moe.manager import MOE_MANAGER -from colossalai.utils import get_current_device class MoeRouter(nn.Module, ABC): @@ -24,14 +24,16 @@ class MoeRouter(nn.Module, ABC): drop_tks (bool, optional): Whether drops tokens in evaluation """ - def __init__(self, - k_value: int, - capacity_factor_train: float, - capacity_factor_eval: float, - min_capacity: int, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True, - use_kernel: bool = False): + def __init__( + self, + k_value: int, + capacity_factor_train: float, + capacity_factor_eval: float, + min_capacity: int, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + use_kernel: bool = False, + ): super().__init__() self.k_value = k_value self.capacity_factor_train = capacity_factor_train @@ -68,8 +70,9 @@ class MoeRouter(nn.Module, ABC): if router_probs.dim() == expert_indices.dim() == 2: router_probs = router_probs.unsqueeze(0) expert_indices = expert_indices.unsqueeze(0) - assert router_probs.dim() == expert_indices.dim() == 3, \ - "router_probs must be 3D tensor and expert_indices must be 4D tensor" + assert ( + router_probs.dim() == expert_indices.dim() == 3 + ), "router_probs must be 3D tensor and expert_indices must be 4D tensor" # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. expert_mask = F.one_hot(expert_indices, num_experts) @@ -122,25 +125,29 @@ class Top1Router(MoeRouter): drop_tks (bool, optional): Whether drops tokens in evaluation """ - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - select_policy: str = "first", - noisy_func: Optional[Callable] = None, - drop_tks: bool = True): - super().__init__(k_value=1, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) + def __init__( + self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + select_policy: str = "first", + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + ): + super().__init__( + k_value=1, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) self.select_policy = select_policy assert select_policy in {"first", "random"} if select_policy == "random": self.uniform = torch.distributions.uniform.Uniform( - low=torch.tensor(0.0, device=get_current_device()), - high=torch.tensor(1.0, device=get_current_device()) + low=torch.tensor(0.0, device=get_accelerator().get_current_device()), + high=torch.tensor(1.0, device=get_accelerator().get_current_device()), ).rsample def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: @@ -216,18 +223,22 @@ class Top2Router(MoeRouter): drop_tks (bool, optional): Whether drops tokens in evaluation. """ - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True): - super().__init__(k_value=2, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) + def __init__( + self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + ): + super().__init__( + k_value=2, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: """ @@ -255,8 +266,8 @@ class Top2Router(MoeRouter): top2_idx = torch.argmax(logits_except1, dim=-1) mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) - cmask = (mask1 + mask2) # loss: [s, e] - cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 + cmask = mask1 + mask2 # loss: [s, e] + cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 # calculate loss expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) @@ -269,7 +280,7 @@ class Top2Router(MoeRouter): dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) capacity = max_num.item() - rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e] + rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e] rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel) rank2 += torch.sum(mask1, dim=-2, keepdim=True) @@ -336,15 +347,18 @@ class TopKRouter(MoeRouter): oversubscribed / reach capacity. """ - def __init__(self, - num_selected_experts: int, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True): - super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, - drop_tks) + def __init__( + self, + num_selected_experts: int, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + ): + super().__init__( + num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks + ) def forward( self, @@ -410,7 +424,7 @@ class TopKRouter(MoeRouter): # The combine array will be used for combining expert outputs, scaled by the # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, # expert_capacity]. - combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask) + combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) return combine_array, dispatch_mask diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py index 5a17a6e0d..e25e7dd48 100644 --- a/colossalai/moe/utils.py +++ b/colossalai/moe/utils.py @@ -7,13 +7,12 @@ import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F +from colossalai.accelerator import get_accelerator from colossalai.moe.manager import MOE_MANAGER from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor -from colossalai.utils import get_current_device class ForceFP32Parameter(torch.nn.Parameter): - def half(self, memory_format=None): return self.data.clone() @@ -30,8 +29,8 @@ class NormalNoiseGenerator: def __init__(self, num_experts: int): self.normal = torch.distributions.normal.Normal( - loc=torch.tensor(0.0, device=get_current_device()), - scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()), + loc=torch.tensor(0.0, device=get_accelerator().get_current_device()), + scale=torch.tensor(1.0 / num_experts**2, device=get_accelerator().get_current_device()), ).rsample def __call__(self, inputs: torch.Tensor): @@ -52,8 +51,8 @@ class UniformNoiseGenerator: def __init__(self, eps: float = 1e-2): self.uniform = torch.distributions.uniform.Uniform( - low=torch.tensor(1.0 - eps, device=get_current_device()), - high=torch.tensor(1.0 + eps, device=get_current_device()), + low=torch.tensor(1.0 - eps, device=get_accelerator().get_current_device()), + high=torch.tensor(1.0 + eps, device=get_accelerator().get_current_device()), ).rsample def __call__(self, inputs: torch.Tensor): @@ -142,7 +141,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]] epsize_param_dict = dict() for param in model.parameters(): if not is_moe_tensor(param): - ep_size = 1 # set ep_size to 1 for dp parameters + ep_size = 1 # set ep_size to 1 for dp parameters else: ep_size = get_ep_size(param) if ep_size not in epsize_param_dict: @@ -193,18 +192,13 @@ def create_ep_hierarchical_group( assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually." nproc_per_node = int(nproc_per_node) else: - assert dist.get_world_size() % nproc_per_node == 0, \ - "nproc_per_node should be a divisor of world_size." + assert dist.get_world_size() % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size." num_node = dist.get_world_size() // nproc_per_node intra_src_rank = None ep_intra_node_group = None for i in range(num_node): - ep_intra_ranks = [ - i * nproc_per_node + j - for j in range(nproc_per_node) - if j in ep_group_ranks - ] + ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks] group = dist.new_group(ep_intra_ranks) if rank in ep_intra_ranks: assert ep_intra_node_group is None @@ -212,10 +206,7 @@ def create_ep_hierarchical_group( intra_src_rank = ep_intra_ranks[0] ep_inter_node_group = None - ep_inter_ranks = [ - ep_group_ranks[0] + i * nproc_per_node - for i in range(num_node) - ] + ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)] if len(ep_inter_ranks) > 1: group = dist.new_group(ep_inter_ranks) if rank in ep_inter_ranks: diff --git a/colossalai/nn/layer/colo_attention.py b/colossalai/nn/layer/colo_attention.py new file mode 100644 index 000000000..0b7011e8e --- /dev/null +++ b/colossalai/nn/layer/colo_attention.py @@ -0,0 +1,209 @@ +import enum +import math +import warnings +from dataclasses import dataclass +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn.functional as F +from einops import rearrange + +from colossalai.accelerator import get_accelerator +from colossalai.kernel.kernel_loader import FlashAttentionLoader + + +@dataclass +class SeqLenInfo: + seqlens: Iterable[int] = None + indices: torch.Tensor = None + max_seqlen: int = None + cu_seqlens: torch.Tensor = None + + @staticmethod + def materialize( + attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device() + ): + if attn_mask is not None: + indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device) + seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten() + else: + batch_size, tgt_len = size[0], size[1] + indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device) + seqlens = torch.LongTensor([tgt_len] * batch_size, device=device) + max_seqlen = max(seqlens) + cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device) + return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens) + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 + paddedcausal = 3 + + +class Unpad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): + ctx.save_for_backward(indices) + # [b, s, ...] + assert tensor.ndim >= 3 + ctx.bsz = tensor.shape[0] + out = rearrange(tensor, "b s ... -> (b s) ...") + ctx.shape = out.shape + # [ntokens, ...] + return out[indices] + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # [ntokens, ...] + grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) + grad[indices] = grad_output + grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz) + # [b, s, ...] + return grad, None + + +class Repad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): + ctx.save_for_backward(indices) + # [ntokens, ...] + tensor = tensor + out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) + # [b*s, ...] + out[indices] = tensor + return out + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # [b*s, ...] + grad = grad_output[indices] + # [ntokens, ...] + return grad, None, None, None + + +class ColoAttention(torch.nn.Module): + def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None): + super().__init__() + assert ( + embed_dim % num_heads == 0 + ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." + if scale is not None: + self.scale = scale + else: + self.scale = 1 / math.sqrt(embed_dim // num_heads) + self.dropout = dropout + + self.attn = FlashAttentionLoader().load() + + @staticmethod + def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + return Unpad.apply(tensor, indices) + + @staticmethod + def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: + return Repad.apply(tensor, indices, batch_size, seq_len) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + origin_attn_mask: Optional[torch.Tensor] = None, + attn_mask_type: Optional[AttnMaskType] = None, + bias: Optional[torch.Tensor] = None, + ): + """ + ColoAttention + + Args: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + origin_attn_mask: (nheads, q_seqlen, kv_seqlen) + bias: will not be used + Return: + attn_out: (batch, q_seqlen, nheads, headdim). + """ + # if flash attention is not applicable, switch to memory effcient attention + if self.attn.__name__ == "flash_attention" and ( + query.dtype not in [torch.float16, torch.bfloat16] or bias != None + ): + warnings.warn( + f"flash-attn expects fp16 or bf16 but got {query.dtype}, switching to xformers' implementation." + ) + self.attn = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda") + + padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 + causal = attn_mask_type is not None and attn_mask_type.value > 1 + + batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] + # unpad + seq_len_info_q = None + seq_len_info_kv = None + if padded: + # bert style, unpad process + assert ( + attn_mask is not None + ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." + assert attn_mask.dim() == 2, ( + "attention mask is supposed to have shape (batch_size, seq_len), " + + f"but got {attn_mask.dim()} dimensions." + ) + + # bert style + if tgt_len == src_len: + seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) + if batch_size > 1: + query, key, value = self.unpad( + torch.stack([query, key, value], dim=2), seq_len_info_q.indices + ).unbind(dim=1) + else: + query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) + seq_len_info_kv = seq_len_info_q + else: + seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device) + seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) + if batch_size > 1: + query = rearrange(query, "b s ... -> c (b s) ...", c=1) + key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind( + dim=1 + ) + else: + query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) + + out = self.attn( + query, + key, + value, + seq_len_info_q=seq_len_info_q, + seq_len_info_kv=seq_len_info_kv, + origin_attn_mask=origin_attn_mask, + dropout_p=self.dropout, + scale=self.scale, + causal=causal, + padded=padded, + ) + + # repad + if padded: + if batch_size > 1: + out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len) + out = rearrange(out, "(b s) h d -> b s h d", b=batch_size) + + if len(out.shape) == 4: + out = rearrange(out, "b s h d -> b s (h d)") + return out diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/nn/layer/layernorm.py similarity index 95% rename from colossalai/kernel/cuda_native/layer_norm.py rename to colossalai/nn/layer/layernorm.py index c7d2a3a45..1db48faee 100644 --- a/colossalai/kernel/cuda_native/layer_norm.py +++ b/colossalai/nn/layer/layernorm.py @@ -9,7 +9,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd from torch.nn import init from torch.nn.parameter import Parameter -from colossalai.kernel.op_builder.layernorm import LayerNormBuilder +from colossalai.kernel.kernel_loader import LayerNormLoader try: from colossalai._C import layer_norm @@ -29,7 +29,7 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): global layer_norm if layer_norm is None: - layer_norm = LayerNormBuilder().load() + layer_norm = LayerNormLoader().load() output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps) ctx.layernorm_op = layer_norm ctx.save_for_backward(input_, weight_, bias_, mean, invvar) diff --git a/colossalai/nn/layer/scaled_softmax.py b/colossalai/nn/layer/scaled_softmax.py new file mode 100644 index 000000000..a8d72ddd9 --- /dev/null +++ b/colossalai/nn/layer/scaled_softmax.py @@ -0,0 +1,184 @@ +# This code from NVIDIA Megatron: +# with minor changes. + +import enum + +import torch +import torch.nn as nn + +from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 + paddedcausal = 3 + + +class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + + 1. Scale the tensor. + 2. Apply upper triangular mask (typically used in gpt models). + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, scale): + global scaled_upper_triang_masked_softmax + if scaled_upper_triang_masked_softmax: + scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load() + + scale_t = torch.tensor([scale]) + softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0]) + + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + softmax_results, scale_t = ctx.saved_tensors + input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + + return input_grads, None + + +class ScaledMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + + 1. Scale the tensor. + 2. Apply the mask. + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, mask, scale): + scale_t = torch.tensor([scale]) + + # build and load kernel if not pre-built + global scaled_masked_softmax + if scaled_masked_softmax is None: + scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load() + + softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0]) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + softmax_results, scale_t = ctx.saved_tensors + + input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + return input_grads, None, None, None + + +class FusedScaleMaskSoftmax(nn.Module): + """ + Fused operation: scaling + mask + softmax + + Arguments: + input_in_fp16: Flag to indicate if input in fp16 data format. + input_in_bf16: Flag to indicate if input in bf16 data format. + attn_mask_type: Attention mask type (pad or causal) + scaled_masked_softmax_fusion: Flag to indicate user want to use softmax fusion + mask_func: Mask function to be applied. + softmax_in_fp32: If True, softmax in performed at fp32 precision. + scale: Scaling factor used in input tensor scaling. + """ + + def __init__( + self, + input_in_fp16, + input_in_bf16, + attn_mask_type, + scaled_masked_softmax_fusion, + mask_func, + softmax_in_fp32, + scale, + ): + super(FusedScaleMaskSoftmax, self).__init__() + self.input_in_fp16 = input_in_fp16 + self.input_in_bf16 = input_in_bf16 + assert not ( + self.input_in_fp16 and self.input_in_bf16 + ), "both fp16 and bf16 flags cannot be active at the same time." + self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 + self.attn_mask_type = attn_mask_type + self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion + self.mask_func = mask_func + self.softmax_in_fp32 = softmax_in_fp32 + self.scale = scale + assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled" + + def forward(self, input, mask): + # [b, np, sq, sk] + assert input.dim() == 4 + + if self.is_kernel_available(mask, *input.size()): + return self.forward_fused_softmax(input, mask) + else: + return self.forward_torch_softmax(input, mask) + + def is_kernel_available(self, mask, b, np, sq, sk): + attn_batches = b * np + + if ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and mask is not None # mask tensor must not be None + and 16 < sk <= 2048 # sk must be 16 ~ 2048 + and sq % 4 == 0 # sq must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): + if 0 <= sk <= 2048: + batch_per_block = self.get_batch_per_block(sq, sk, b, np) + + if self.attn_mask_type.value > 1: + if attn_batches % batch_per_block == 0: + return True + else: + if sq % batch_per_block == 0: + return True + return False + + def forward_fused_softmax(self, input, mask): + b, np, sq, sk = input.size() + scale = self.scale if self.scale is not None else 1.0 + + if self.attn_mask_type.value > 1: + assert sq == sk, "causal mask is only for self attention" + + # input is 3D tensor (attn_batches, sq, sk) + input = input.view(-1, sq, sk) + probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) + return probs.view(b, np, sq, sk) + else: + # input is 4D tensor (b, np, sq, sk) + return ScaledMaskedSoftmax.apply(input, mask, scale) + + def forward_torch_softmax(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() + + if self.scale is not None: + input = input * self.scale + mask_output = self.mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + + if self.input_in_float16 and self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() + + return probs + + def get_batch_per_block(self, sq, sk, b, np): + # build and load kernel if not pre-built + global scaled_masked_softmax + if scaled_masked_softmax is None: + scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load() + + return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np) diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index 7d53a1dd6..5be629fb2 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -1,10 +1,9 @@ import math -import platform from typing import Optional import torch -from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder +from colossalai.kernel.kernel_loader import CPUAdamLoader from .nvme_optimizer import NVMeOptimizer @@ -78,7 +77,7 @@ class CPUAdam(NVMeOptimizer): default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) self.adamw_mode = adamw_mode - cpu_adam = ArmCPUAdamBuilder().load() if platform.machine() == "aarch64" else CPUAdamBuilder().load() + cpu_adam = CPUAdamLoader().load() # if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py index fcdd3257d..aeb5cc91b 100644 --- a/colossalai/nn/optimizer/fused_adam.py +++ b/colossalai/nn/optimizer/fused_adam.py @@ -70,9 +70,9 @@ class FusedAdam(torch.optim.Optimizer): self.adamw_mode = 1 if adamw_mode else 0 self.set_grad_none = set_grad_none if multi_tensor_applier.available: - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() # Skip buffer self._dummy_overflow_buf = torch.cuda.IntTensor([0]) diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py index 3e1d5a7ba..da8d1608a 100644 --- a/colossalai/nn/optimizer/fused_lamb.py +++ b/colossalai/nn/optimizer/fused_lamb.py @@ -77,9 +77,9 @@ class FusedLAMB(torch.optim.Optimizer): ) super(FusedLAMB, self).__init__(params, defaults) if multi_tensor_applier.available: - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm # Skip buffer diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py index 95a635420..3fae9bbca 100644 --- a/colossalai/nn/optimizer/fused_sgd.py +++ b/colossalai/nn/optimizer/fused_sgd.py @@ -72,9 +72,9 @@ class FusedSGD(Optimizer): self.wd_after_momentum = wd_after_momentum if multi_tensor_applier.available: - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() # Skip buffer self._dummy_overflow_buf = torch.tensor( diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index d34fd601a..c9c1f81bf 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -2,7 +2,7 @@ from typing import Any, Optional import torch -from colossalai.kernel.op_builder import FusedOptimBuilder +from colossalai.kernel.kernel_loader import FusedOptimizerLoader from colossalai.utils import multi_tensor_applier from .cpu_adam import CPUAdam @@ -85,7 +85,7 @@ class HybridAdam(CPUAdam): nvme_offload_dir, ) if torch.cuda.is_available(): - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() self.gpu_adam_op = fused_optim.multi_tensor_adam self._dummy_overflow_buf = torch.cuda.IntTensor([0]) diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index 72480526b..20f316c2a 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -7,10 +7,10 @@ import torch.cuda from torch.nn import Module from torch.utils._pytree import tree_map +from colossalai.accelerator import get_accelerator from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.utils.device import get_current_device from ._utils import get_batch_size, get_micro_batch, model_forward, to_device from .base import PipelineSchedule @@ -86,7 +86,7 @@ class GenerateSchedule(PipelineSchedule): """ micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size) self.microbatch_offset += self.microbatch_size - return tree_map(partial(to_device, device=get_current_device()), micro_batch) + return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) def _prepare_inputs_for_interval_stage(self): """ diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 0a01a1e78..a4ace5e1b 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -6,10 +6,11 @@ import torch.cuda from torch.nn import Module, ModuleList from torch.utils._pytree import tree_map +from colossalai.accelerator import get_accelerator from colossalai.interface import OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.utils.device import get_current_device +from colossalai.utils import get_current_device from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from .base import PipelineSchedule @@ -72,6 +73,10 @@ class InterleavedSchedule(PipelineSchedule): assert self.last_batch_size is None or self.last_batch_size == self.batch_size assert self.batch_size == self.microbatch_size * self.num_microbatch + assert ( + self.num_microbatch % self.stage_manager.num_stages == 0 + ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices" + if self.forward_only: self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 # NOTE: disable metadata cache when batch size changes (not valid anymore) @@ -96,7 +101,7 @@ class InterleavedSchedule(PipelineSchedule): assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted" micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size) self.microbatch_offset[model_chunk_id] += self.microbatch_size - return tree_map(partial(to_device, device=get_current_device()), micro_batch) + return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int: """Helper method to get the model chunk ID given the iteration number. diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index cb078b25f..bf2f01b10 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -6,10 +6,11 @@ import torch.cuda from torch.nn import Module from torch.utils._pytree import tree_map +from colossalai.accelerator import get_accelerator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.utils.device import get_current_device +from colossalai.utils import get_current_device from ._utils import ( detach, @@ -85,6 +86,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): assert self.last_batch_size is None or self.last_batch_size == self.batch_size assert self.batch_size == self.microbatch_size * self.num_microbatches + assert ( + self.num_microbatches >= self.stage_manager.num_stages + ), "Number of microbatch should be larger than number of stages" + if self.forward_only: self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1 # NOTE: disable metadata cache when batch size changes (not valid anymore) @@ -106,7 +111,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): assert self.microbatch_offset <= self.batch_size, "Microbatches exhausted" micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size) self.microbatch_offset += self.microbatch_size - return tree_map(partial(to_device, device=get_current_device()), micro_batch) + return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) def recv_forward(self, prev_rank: int = None) -> Any: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. @@ -313,7 +318,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): accum_loss = None if return_loss and self.stage_manager.is_last_stage(): - accum_loss = torch.scalar_tensor(0, device=get_current_device()) + accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device()) outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None for _ in range(self.num_microbatches): diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 4b6343adc..0d2cc1b33 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -6,7 +6,8 @@ import torch.distributed as dist from torch import nn from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch.distributed import ProcessGroup, get_world_size -from colossalai.utils.device import get_current_device, get_rng_state, set_rng_state, manual_seed + +from colossalai.accelerator import get_accelerator class SeqParallelUtils: @@ -109,10 +110,10 @@ class Randomizer: # 1. get the current rng state # 2. set the seed and store the rng state # 3. recover the original rng state - device_original_rng_state = get_rng_state() - manual_seed(seed) - self.device_rng_state = get_rng_state() - set_rng_state(device_original_rng_state) + device_original_rng_state = get_accelerator().get_rng_state() + get_accelerator().manual_seed(seed) + self.device_rng_state = get_accelerator().get_rng_state() + get_accelerator().set_rng_state(device_original_rng_state) # to the same for cpu rng state cpu_original_rng_state = torch.get_rng_state() @@ -121,10 +122,10 @@ class Randomizer: torch.set_rng_state(cpu_original_rng_state) def _set_device_rng_state(self, rng_state): - set_rng_state(rng_state) + get_accelerator().set_rng_state(rng_state) def _get_device_rng_state(self): - current_state = get_rng_state() + current_state = get_accelerator().get_rng_state() return current_state def _set_cpu_rng_state(self, rng_state): @@ -209,7 +210,7 @@ class Randomizer: index = Randomizer.index() if dist.is_initialized(): # convert the index to tensor - index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device()) + index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device()) # all gather the index gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))] @@ -231,7 +232,7 @@ class Randomizer: if dist.is_initialized(): # convert the index to tensor - index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device()) + index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device()) # all gather the index gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))] diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py index 00b2037fb..d5c10541a 100644 --- a/colossalai/shardformer/modeling/blip2.py +++ b/colossalai/shardformer/modeling/blip2.py @@ -62,7 +62,7 @@ def forward_fn(): def get_blip2_flash_attention_forward(): from transformers.models.blip_2.modeling_blip_2 import Blip2Attention - from colossalai.kernel.cuda_native import ColoAttention + from colossalai.nn.layer.colo_attention import ColoAttention def forward( self: Blip2Attention, diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index c8a311df7..d13bd3492 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -14,7 +14,7 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM def get_flash_core_attention_forward(): - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention from .chatglm2_6b.modeling_chatglm import CoreAttention diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 8f4563537..055e3096d 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -719,7 +719,7 @@ class GPT2PipelineForwards: def get_gpt2_flash_attention_forward(): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def split_heads(tensor, num_heads, attn_head_size): """ diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py index ad51bf2c7..22b0f7a90 100644 --- a/colossalai/shardformer/modeling/gptj.py +++ b/colossalai/shardformer/modeling/gptj.py @@ -530,7 +530,7 @@ class GPTJPipelineForwards: def get_gptj_flash_attention_forward(): from transformers.models.gptj.modeling_gptj import GPTJAttention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def split_heads(tensor, num_attention_heads, attn_head_size, rotary): """ diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 1b53ce4af..e10a7ed7d 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -3,7 +3,6 @@ from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F -import torch.distributed as dist from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.modeling_outputs import ( BaseModelOutputWithPast, @@ -15,14 +14,17 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig + from ..layer import cross_entropy_1d try: from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask + LATEST_VERSION = True except ImportError: LATEST_VERSION = False + class LlamaPipelineForwards: """ This class serves as a micro library for forward function substitution of Llama models @@ -203,7 +205,7 @@ class LlamaPipelineForwards: stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None + shard_config: ShardConfig = None, ): r""" Args: @@ -279,12 +281,13 @@ class LlamaPipelineForwards: if shard_config.enable_tensor_parallelism: new_vocab_size = logits.shape[-1] shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d(shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group) + loss = cross_entropy_1d( + shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + ) else: shift_logits = shift_logits.view(-1, self.config.vocab_size) loss = loss_fct(shift_logits, shift_labels) - if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output @@ -417,7 +420,7 @@ class LlamaPipelineForwards: def get_llama_flash_attention_forward(shard_config: ShardConfig): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention llama_version = 2 try: @@ -480,7 +483,12 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig): attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) attn_output = attention( - query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type + query_states, + key_states, + value_states, + attn_mask=flash_attention_mask, + attn_mask_type=attn_mask_type, + origin_attn_mask=attention_mask, ) attn_output = self.o_proj(attn_output) @@ -492,7 +500,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig): def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): from transformers import LlamaForCausalLM - + def forward( self: LlamaForCausalLM, input_ids: torch.LongTensor = None, @@ -573,12 +581,13 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): if shard_config.enable_tensor_parallelism: new_vocab_size = logits.shape[-1] shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d(shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group) + loss = cross_entropy_1d( + shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + ) else: shift_logits = shift_logits.view(-1, self.config.vocab_size) loss = loss_fct(shift_logits, shift_labels) - if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output @@ -590,4 +599,5 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + return forward diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 1ddb26c25..0da1a35a0 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -6,7 +6,7 @@ import torch def get_mistral_flash_attention_forward(): from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def forward( self: MistralAttention, diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 71f2ca335..7f6cbbbcf 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -514,7 +514,7 @@ class OPTPipelineForwards: def get_opt_flash_attention_forward(): from transformers.models.opt.modeling_opt import OPTAttention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def forward( self: OPTAttention, diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index f67aa84e4..dcb178520 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -593,10 +593,6 @@ class T5PipelineForwards: def get_t5_flash_attention_forward(): - try: - from xformers.ops import memory_efficient_attention as me_attention - except: - raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") from transformers.models.t5.modeling_t5 import T5Attention def forward( @@ -632,11 +628,11 @@ def get_t5_flash_attention_forward(): def shape(states): """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim) + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) def unshape(states): """reshape""" - return states.view(batch_size, -1, self.inner_dim) + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) def project(hidden_states, proj_layer, key_value_states, past_key_value): """projects hidden states correctly to key/query states""" @@ -653,8 +649,8 @@ def get_t5_flash_attention_forward(): if key_value_states is None: # self-attn # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=1) - elif past_key_value.shape[1] != key_value_states.shape[1]: + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: # checking that the `sequence_length` of the `past_key_value` is the same as # the provided `key_value_states` to support prefix tuning # cross-attn @@ -701,10 +697,15 @@ def get_t5_flash_attention_forward(): else: position_bias_masked = position_bias - position_bias_masked = position_bias_masked.contiguous() - attn_output = me_attention( - query_states, key_states, value_states, attn_bias=position_bias_masked, p=self.dropout, scale=1.0 - ) + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True): + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=position_bias_masked, + dropout_p=self.dropout, + scale=1.0, + ) attn_output = unshape(attn_output) attn_output = self.o(attn_output) diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index 5a50e7379..ab141a74a 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -336,7 +336,7 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag def get_vit_flash_self_attention_forward(): from transformers.models.vit.modeling_vit import ViTSelfAttention - from colossalai.kernel.cuda_native import ColoAttention + from colossalai.nn.layer.colo_attention import ColoAttention def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor: new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size) diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 9827d4801..cb8b45ae7 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -26,7 +26,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager def get_whisper_flash_attention_forward(): from transformers.models.whisper.modeling_whisper import WhisperAttention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int): return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous() diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index f2eeb9d69..5c148880f 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -49,7 +49,7 @@ class FalconPolicy(Policy): if not self.model.config.new_decoder_architecture and self.model.config.multi_query: warnings.warn( - "Falcon dosen't support tensor parallelism when (not new_decoder_architecture and multi_query) is True, will ignore the tensor parallelism flag." + "Falcon doesn't support tensor parallelism when (not new_decoder_architecture and multi_query) is True, will ignore the tensor parallelism flag." ) self.shard_config.enable_tensor_parallelism = False diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 1faa24f71..42bf0825b 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -46,7 +46,7 @@ class LlamaPolicy(Policy): if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False - warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: decoder_attribute_replacement = { diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index c16aa6dea..c0b8b3375 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -35,7 +35,7 @@ class MistralPolicy(Policy): if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn( - "Mistral dosen't support sequence parallelism now, will ignore the sequence parallelism flag." + "Mistral doesn't support sequence parallelism now, will ignore the sequence parallelism flag." ) if self.shard_config.enable_tensor_parallelism: @@ -136,7 +136,7 @@ class MistralModelPolicy(MistralPolicy): def module_policy(self): if self.pipeline_stage_manager: - warnings.warn("Mistral dosen't support pipeline parallelism now.") + warnings.warn("Mistral doesn't support pipeline parallelism now.") return super().module_policy() @@ -160,7 +160,7 @@ class MistralForCausalLMPolicy(MistralPolicy): } if self.pipeline_stage_manager: - warnings.warn("Mistral dosen't support pipeline parallelism now.") + warnings.warn("Mistral doesn't support pipeline parallelism now.") policy.update(new_item) @@ -186,7 +186,7 @@ class MistralForSequenceClassificationPolicy(MistralPolicy): } if self.pipeline_stage_manager: - warnings.warn("Mistral dosen't support pipeline parallelism now.") + warnings.warn("Mistral doesn't support pipeline parallelism now.") policy.update(new_item) return policy diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index e2f3a829c..a542808ba 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -59,7 +59,7 @@ class OPTPolicy(Policy): if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False - warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: policy[OPTDecoder] = ModulePolicyDescription( diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 4d906e3f4..e183b0632 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -66,7 +66,7 @@ class T5BasePolicy(Policy): if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False - warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: policy[T5Stack] = ModulePolicyDescription( @@ -263,7 +263,7 @@ class T5BasePolicy(Policy): if num_decoder_layers == 0: return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages - # the number of stages distributed between encoder and decoder is optmized in this way: + # the number of stages distributed between encoder and decoder is optimized in this way: # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1 def objective(num_encoder_stages): diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 6ef0e3b34..584d4e265 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -33,7 +33,7 @@ class ViTPolicy(Policy): if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False - warnings.warn("Vit dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + warnings.warn("Vit doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: policy[ViTEmbeddings] = ModulePolicyDescription( diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 6dae99e8c..b5b5db79d 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -69,13 +69,13 @@ class WhisperPolicy(Policy): if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn( - "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag." + "Whisper doesn't support sequence parallelism now, will ignore the sequence parallelism flag." ) # TODO using the jit fused add_and_dropout affect the accuracy if self.shard_config.enable_jit_fused: self.shard_config.enable_jit_fused = False - warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused operator flag.") + warnings.warn("Whisper doesn't support jit fused operator now, will ignore the jit fused operator flag.") if self.shard_config.enable_tensor_parallelism: policy[WhisperEncoderLayer] = ModulePolicyDescription( @@ -302,7 +302,7 @@ class WhisperPolicy(Policy): if num_decoder_layers == 0: return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages - # the number of stages distributed between encoder and decoder is optmized in this way: + # the number of stages distributed between encoder and decoder is optimized in this way: # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1 def objective(num_encoder_stages): diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py index de0cba26b..27afac9e9 100644 --- a/colossalai/tensor/comm_spec.py +++ b/colossalai/tensor/comm_spec.py @@ -451,7 +451,7 @@ class CommSpec: elif self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: res_list.append(f"comm_pattern:MIXGATHER_FWD_SPLIT_BWD, ") res_list.append(f"gather_dim:{self.gather_dim}, ") - res_list.append(f"logical_process_asex:{self.logical_process_axes})") + res_list.append(f"logical_process_axes:{self.logical_process_axes})") return "".join(res_list) diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index 74a785f2d..da6ef275e 100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -96,9 +96,9 @@ def _apply_layout(tensor, layout): """ Apply the layout to the local tensor during initializing process. """ - # layout converter requires a source and target laytout + # layout converter requires a source and target layout # we construct the source layer for an unsharded tensor - # and use self.dist_layer as the targer layout for the sharded tensor + # and use self.dist_layer as the target layout for the sharded tensor source_spec = _construct_default_sharding_spec(tensor) source_layout = Layout(device_mesh=layout.device_mesh, sharding_spec=source_spec, global_shape=tensor.shape) sharded_tensor = layout_converter.apply(tensor=tensor, source_layout=source_layout, target_layout=layout) diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py index 1e4486101..b6843df7a 100644 --- a/colossalai/tensor/moe_tensor/api.py +++ b/colossalai/tensor/moe_tensor/api.py @@ -40,7 +40,7 @@ def get_moe_info(ep_size: int, dp_size: int, pp_size: int, ep_inside: bool) -> M ep_size (int): The expert parallel size. dp_size (int): The data parallel size. pp_size (int): The pipeline parallel size. - ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. + ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Returns: dict: The moe info of the given tensor. diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py index 5097ac104..ba6c77056 100644 --- a/colossalai/tensor/moe_tensor/moe_info.py +++ b/colossalai/tensor/moe_tensor/moe_info.py @@ -12,7 +12,7 @@ class MoeParallelInfo: ep_size (int): expert parallel size dp_size (int): data parallel (zero) size pp_size (int, optional): pipeline parallel size. Defaults to 1. - ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Defaults to True. + ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True. """ self.pp_size, self.dp_size, self.ep_size = pp_size, dp_size, ep_size if ep_inside: diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py index 7cd24b0ad..5f6864ff0 100644 --- a/colossalai/testing/utils.py +++ b/colossalai/testing/utils.py @@ -9,7 +9,8 @@ from typing import Any, Callable, List import torch import torch.multiprocessing as mp from packaging import version -from colossalai.utils.device import empty_cache, reset_max_memory_allocated, reset_peak_memory_stats, synchronize, reset_max_memory_cached, device_count + +from colossalai.accelerator import get_accelerator def parameterize(argument: str, values: List[Any]) -> Callable: @@ -199,7 +200,7 @@ def skip_if_not_enough_gpus(min_gpus: int): def _wrap_func(f): def _execute_by_gpu_num(*args, **kwargs): - num_avail_gpu = device_count() + num_avail_gpu = get_accelerator().device_count() if num_avail_gpu >= min_gpus: f(*args, **kwargs) @@ -263,11 +264,11 @@ def clear_cache_before_run(): def _wrap_func(f): def _clear_cache(*args, **kwargs): - empty_cache() - reset_peak_memory_stats() - reset_max_memory_allocated() - reset_max_memory_cached() - synchronize() + get_accelerator().empty_cache() + get_accelerator().reset_peak_memory_stats() + get_accelerator().reset_max_memory_allocated() + get_accelerator().reset_max_memory_cached() + get_accelerator().synchronize() gc.collect() f(*args, **kwargs) diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index 0246a35e2..cdba46709 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -4,20 +4,16 @@ from .common import ( disposable, ensure_path_exists, free_storage, + get_current_device, is_ddp_ignored, set_seed, ) -from .device import IS_NPU_AVAILABLE, empty_cache, get_current_device, set_device, set_to_cuda, synchronize from .multi_tensor_apply import multi_tensor_applier from .tensor_detector import TensorDetector from .timer import MultiTimer, Timer __all__ = [ "conditional_context", - "get_current_device", - "synchronize", - "empty_cache", - "set_to_cuda", "Timer", "MultiTimer", "multi_tensor_applier", @@ -27,7 +23,6 @@ __all__ = [ "_cast_float", "free_storage", "set_seed", + "get_current_device", "is_ddp_ignored", - "set_device", - "IS_NPU_AVAILABLE", ] diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index c43caaff4..4a1889eb5 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -10,6 +10,15 @@ from typing import Callable import numpy as np import torch +from colossalai.accelerator import get_accelerator + + +def get_current_device(): + """ + A wrapper function for accelerator's API for backward compatibility. + """ + return get_accelerator().get_current_device() + def ensure_path_exists(filename: str): # ensure the path exists diff --git a/colossalai/utils/device.py b/colossalai/utils/device.py deleted file mode 100644 index c70dbdaa5..000000000 --- a/colossalai/utils/device.py +++ /dev/null @@ -1,223 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from typing import Any, Dict, List, Optional, Tuple, Callable - -import torch -import torch.distributed as dist - -IS_NPU_AVAILABLE: bool = False -try: - import torch_npu # noqa - - IS_NPU_AVAILABLE = torch.npu.is_available() -except ImportError: - pass - - -def set_to_cuda(models): - """Send model to gpu. - - :param models: nn.module or a list of module - """ - if isinstance(models, list) and len(models) > 1: - ret = [] - for model in models: - ret.append(model.to(get_current_device())) - return ret - elif isinstance(models, list): - return models[0].to(get_current_device()) - else: - return models.to(get_current_device()) - - -def get_current_device() -> torch.device: - """ - Returns currently selected device (gpu/cpu). - If cuda available, return gpu, otherwise return cpu. - """ - if torch.cuda.is_available(): - return torch.device(f"cuda:{torch.cuda.current_device()}") - elif IS_NPU_AVAILABLE: - return torch.device(f"npu:{torch.npu.current_device()}") - else: - return torch.device("cpu") - - -def _dispatch_device_func(fn_name: str, *args, **kwargs): - if torch.cuda.is_available(): - return getattr(torch.cuda, fn_name)(*args, **kwargs) - elif IS_NPU_AVAILABLE: - return getattr(torch.npu, fn_name)(*args, **kwargs) - else: - raise RuntimeError("No device available") - - -# device semantics - - -def can_device_access_peer(device, peer_device) -> bool: - return _dispatch_device_func("can_device_access_peer", device, peer_device) - - -def current_device() -> int: - return _dispatch_device_func("current_device") - - -def current_stream(device=None): - return _dispatch_device_func("current_stream", device) - - -def default_stream(device=None): - return _dispatch_device_func("default_stream", device) - - -def device_count() -> int: - return _dispatch_device_func("device_count") - - -def get_device_capability(device=None) -> Tuple[int, int]: - return _dispatch_device_func("get_device_capability", device) - - -def get_device_name(device=None) -> str: - return _dispatch_device_func("get_device_name", device) - - -def get_device_properties(device): - return _dispatch_device_func("get_device_properties", device) - - -def set_device(index: Optional[int] = None) -> None: - if index is None: - index = dist.get_rank() % device_count() - _dispatch_device_func("set_device", index) - - -def set_stream(stream_): - return _dispatch_device_func("set_stream", stream_) - - -def stream(stream_): - return _dispatch_device_func("stream", stream_) - - -def synchronize(): - return _dispatch_device_func("synchronize") - - -def utilization(device=None) -> int: - return _dispatch_device_func("utilization", device) - - -# random number generator - - -def get_rng_state(device="cuda") -> torch.Tensor: - return _dispatch_device_func("get_rng_state", device) - - -def get_rng_state_all() -> List[torch.Tensor]: - return _dispatch_device_func("get_rng_state_all") - - -def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None: - return _dispatch_device_func("set_rng_state", new_state, device) - - -def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None: - return _dispatch_device_func("set_rng_state_all", new_states) - - -def manual_seed(seed: int) -> None: - return _dispatch_device_func("manual_seed", seed) - - -def manual_seed_all(seed: int) -> None: - return _dispatch_device_func("manual_seed_all", seed) - - -def seed() -> None: - return _dispatch_device_func("seed") - - -def seed_all() -> None: - return _dispatch_device_func("seed_all") - - -def initial_seed() -> int: - return _dispatch_device_func("initial_seed") - - -# streams and events - - -def Stream(device=None, priority=0, **kwargs): - return _dispatch_device_func("Stream", device, priority, **kwargs) - - -def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): - return _dispatch_device_func("Event", enable_timing, blocking, interprocess) - - -# memory management - - -def empty_cache() -> None: - return _dispatch_device_func("empty_cache") - - -def memory_stats(device=None) -> Dict[str, Any]: - return _dispatch_device_func("memory_stats", device) - - -def memory_summary(device=None, abbreviated=False) -> str: - return _dispatch_device_func("memory_summary", device, abbreviated) - - -def memory_snapshot(): - return _dispatch_device_func("memory_snapshot") - - -def memory_allocated(device=None) -> int: - return _dispatch_device_func("memory_allocated", device) - - -def max_memory_allocated(device=None) -> int: - return _dispatch_device_func("max_memory_allocated", device) - - -def reset_max_memory_allocated(device=None) -> None: - return _dispatch_device_func("reset_max_memory_allocated", device) - - -def reset_max_memory_cached(device=None) -> None: - return _dispatch_device_func("reset_max_memory_cached", device) - - -def memory_reserved(device=None) -> int: - return _dispatch_device_func("memory_reserved", device) - - -def max_memory_reserved(device=None) -> int: - return _dispatch_device_func("max_memory_reserved", device) - - -def set_per_process_memory_fraction(fraction: float, device=None) -> None: - return _dispatch_device_func("set_per_process_memory_fraction", fraction, device) - - -def reset_peak_memory_stats(device=None) -> None: - return _dispatch_device_func("reset_peak_memory_stats", device) - - -# amp - - -def autocast() -> Callable: - if torch.cuda.is_available(): - return torch.cuda.amp.autocast() - elif IS_NPU_AVAILABLE: - return torch.npu.amp.autocast() - else: - raise RuntimeError("No device available") diff --git a/colossalai/utils/timer.py b/colossalai/utils/timer.py index 8ab6b46f2..2feded775 100644 --- a/colossalai/utils/timer.py +++ b/colossalai/utils/timer.py @@ -3,7 +3,7 @@ import time from typing import Tuple -from .device import synchronize +from colossalai.accelerator import get_accelerator class Timer: @@ -21,13 +21,13 @@ class Timer: @property def current_time(self) -> float: - synchronize() + get_accelerator().synchronize() return time.time() def start(self): """Firstly synchronize cuda, reset the clock and then start the timer.""" self._elapsed = 0 - synchronize() + get_accelerator().synchronize() self._start_time = time.time() self._started = True @@ -44,7 +44,7 @@ class Timer: Returns: int: Start-stop interval. """ - synchronize() + get_accelerator().synchronize() end_time = time.time() elapsed = end_time - self._start_time if keep_in_history: @@ -123,7 +123,7 @@ class MultiTimer: return None def get_timer(self, name): - """Get timer by its name (from multitimer) + """Get timer by its name (from multimer) Args: name (str): Timer's key. diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index defc6c4cb..cad2622f2 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -6,8 +6,7 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup -from colossalai.utils import get_current_device -from colossalai.utils.device import IS_NPU_AVAILABLE +from colossalai.accelerator import get_accelerator class TensorState(Enum): @@ -107,7 +106,7 @@ class Chunk: self.valid_end = self.shard_size self.dtype = dtype - device = init_device or get_current_device() + device = init_device or get_accelerator().get_current_device() # chunk_temp is a global chunk, which only exists during building the chunks. self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero @@ -125,7 +124,7 @@ class Chunk: # configure the init device of the shard # no-offload default: fp16, fp32 -> CUDA # offload default: fp16, fp32 -> CPU - self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device() + self.shard_device = torch.device("cpu") if cpu_shard_init else get_accelerator().get_current_device() self.chunk_mem = self.chunk_size * self.chunk_temp.element_size() self.shard_mem = self.chunk_mem // self.pg_size @@ -191,11 +190,10 @@ class Chunk: def device_type(self) -> str: if self.chunk_temp is not None: return self.chunk_temp.device.type + elif self.is_gathered or self.cuda_shard is not None: + return get_accelerator().name else: - if self.is_gathered or self.cuda_shard is not None: - return "npu" if IS_NPU_AVAILABLE else "cuda" - else: - return "cpu" + return "cpu" @property def payload(self) -> torch.Tensor: @@ -297,7 +295,7 @@ class Chunk: self.valid_end = self.utilized_size - self.shard_begin if self.chunk_temp.device.type == "cpu": - self.cuda_global_chunk = self.chunk_temp.to(get_current_device()) + self.cuda_global_chunk = self.chunk_temp.to(get_accelerator().get_current_device()) self.__update_tensors_ptr() else: self.cuda_global_chunk = self.chunk_temp @@ -334,12 +332,12 @@ class Chunk: return if device.type == "cuda" or device.type == "npu": - assert device == get_current_device(), "can't move chunk to another device" + assert device == get_accelerator().get_current_device(), "can't move chunk to another device" if self.cuda_shard: return - self.cuda_shard = self.cpu_shard.to(get_current_device()) + self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device()) if not self.pin_memory: self.cpu_shard = None @@ -394,7 +392,9 @@ class Chunk: if self.extra_dp_group is not None: dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group) else: - self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device()) + self.cuda_shard = torch.empty( + self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device() + ) input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0)) dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg) @@ -533,7 +533,7 @@ class Chunk: # only be called when optimizer state is in CPU memory # the grad and param should be in the same device assert self.cuda_shard is None - temp = optim_chunk.cpu_shard.to(get_current_device()) + temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device()) # avoid to transform FP32 in CPU self.cuda_shard = temp.to(self.dtype) @@ -631,7 +631,7 @@ class Chunk: grad_chunk.valid_end = self.valid_end if grad_chunk.chunk_temp.device.type == "cpu": - grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_current_device()) + grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_accelerator().get_current_device()) else: grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp grad_chunk.chunk_temp = None diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 5f4f37c26..5bc662a61 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -5,7 +5,8 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup -from colossalai.utils import free_storage, get_current_device +from colossalai.accelerator import get_accelerator +from colossalai.utils import free_storage from .chunk import Chunk, ChunkFullError, TensorState @@ -20,7 +21,7 @@ class ChunkManager: """ def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None: - self.device = init_device or get_current_device() + self.device = init_device or get_accelerator().get_current_device() self.dp_degree_chunk_size_dict: Dict[int, int] = dict() self.kwargs_config = chunk_configuration for k, v in self.kwargs_config.items(): @@ -107,7 +108,7 @@ class ChunkManager: return self.__sub_memory_usage(chunk.memory_usage) if chunk.device_type == "cpu": - chunk.shard_move(get_current_device()) + chunk.shard_move(get_accelerator().get_current_device()) self.__add_accessed_chunk(chunk) self.__add_memory_usage(chunk.memory_usage) @@ -276,7 +277,10 @@ class ChunkManager: accumulated_grad = chunk.grad_chunk.cuda_shard.clone().detach().mul_(chunk.pg_size) else: accumulated_grad = ( - chunk.grad_chunk.cpu_shard.to(get_current_device()).clone().detach().mul_(chunk.pg_size) + chunk.grad_chunk.cpu_shard.to(get_accelerator().get_current_device()) + .clone() + .detach() + .mul_(chunk.pg_size) ) accumulated_grad_gathered = False diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 5217b8036..79831cf33 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -10,6 +10,7 @@ import torch.nn as nn from torch.distributed import ProcessGroup from torch.distributed.distributed_c10d import _get_default_group +from colossalai.accelerator import get_accelerator from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param from colossalai.interface import ModelWrapper from colossalai.lazy import LazyTensor @@ -27,7 +28,7 @@ from colossalai.tensor.d_tensor import ( is_distributed_tensor, ) from colossalai.tensor.param_op_hook import ColoParamOpHookManager -from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored +from colossalai.utils import _cast_float, free_storage, is_ddp_ignored from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager from .gemini_hook import GeminiZeROHook @@ -766,7 +767,7 @@ class GeminiDDP(ModelWrapper): # move ignored parameters to CUDA if is_ddp_ignored(p): - p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision) + p.data = p.data.to(device=get_accelerator().get_current_device(), dtype=self.mixed_precision) continue # create a fp16 parameter @@ -815,7 +816,7 @@ class GeminiDDP(ModelWrapper): for buffer in self.module.buffers(): if isinstance(buffer, LazyTensor): buffer.materialize() - buffer.data = buffer.to(get_current_device()) + buffer.data = buffer.to(get_accelerator().get_current_device()) if torch.is_floating_point(buffer): buffer.data = buffer.to(self.mixed_precision) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 8f828bd6c..98fbb0c50 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -11,6 +11,7 @@ from torch.distributed import ProcessGroup from torch.nn import Parameter from torch.optim import Optimizer +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param from colossalai.interface import OptimizerWrapper @@ -26,7 +27,7 @@ from colossalai.tensor.d_tensor import ( is_customized_distributed_tensor, is_distributed_tensor, ) -from colossalai.utils import disposable, get_current_device, is_ddp_ignored +from colossalai.utils import disposable, is_ddp_ignored from .chunk import Chunk, ChunkManager from .gemini_ddp import GeminiDDP @@ -233,7 +234,7 @@ class GeminiOptimizer(OptimizerWrapper): grad_chunk.l2_norm = None # clear l2 norm - comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device()) + comm_buffer = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device()) for group, part_norm in group_to_norm.items(): comm_buffer.fill_(part_norm) dist.all_reduce(comm_buffer, group=group) @@ -314,10 +315,10 @@ class GeminiOptimizer(OptimizerWrapper): continue if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem: - self.chunk_manager.move_chunk(chunk32, get_current_device()) + self.chunk_manager.move_chunk(chunk32, get_accelerator().get_current_device()) # stores grad now - self.chunk_manager.move_chunk(chunk16, get_current_device()) - self.module.set_chunk_grad_device(chunk16, get_current_device()) + self.chunk_manager.move_chunk(chunk16, get_accelerator().get_current_device()) + self.module.set_chunk_grad_device(chunk16, get_accelerator().get_current_device()) fp32_params_used_cuda_margin_mem += chunk32.payload_mem for group in self.param_groups: @@ -328,7 +329,7 @@ class GeminiOptimizer(OptimizerWrapper): state = self.optim.state[fake_param] for k, v in state.items(): if isinstance(v, torch.Tensor): - state[k] = v.to(get_current_device()) + state[k] = v.to(get_accelerator().get_current_device()) def _register_states_(self): for group in self.optim.param_groups: @@ -413,7 +414,7 @@ class GeminiOptimizer(OptimizerWrapper): only_rank_0(bool): if True, states will be collected only on master rank, otherwise collected on every rank. Returns: - collected_states(dict): the gathered optimzier state of parameter with given id + collected_states(dict): the gathered optimizer state of parameter with given id if this method is called by master rank, otherwise an empty dict. This method can work only when called by all processes simultaneously. @@ -461,7 +462,7 @@ class GeminiOptimizer(OptimizerWrapper): global_shape = self.optimizer_params_info["id2shape"][param_id] # If the chunk is kept gathered, - # the parameteres are treated the same as that of those in strict DDP during training. + # the parameters are treated the same as that of those in strict DDP during training. # So states can be directly fetched from current device. if chunk.keep_gathered: assert param_id in self.id_to_fake_params @@ -551,7 +552,7 @@ class GeminiOptimizer(OptimizerWrapper): self, param_id: int, state_names: list, - device: torch.device = get_current_device(), + device: torch.device = get_accelerator().get_current_device(), dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ @@ -644,7 +645,7 @@ class GeminiOptimizer(OptimizerWrapper): """ Args: only_rank_0 (bool): a boolean value indicating whether the state_dict is collected - only on rank 0, dafault to True. + only on rank 0, default to True. Returns: The complete state of the optimizer as a :class:`dict`. @@ -783,7 +784,7 @@ class GeminiOptimizer(OptimizerWrapper): prefix (str, optional): the prefix for states. Default to ''. max_shard_size (int, optional): max size of state dict shard (in MB). Defaults to 1024. only_rank_0 (bool, optional): a boolean value indicating whether the state_dict is collected - only on rank 0, dafault to True. + only on rank 0, default to True. Yields: Iterator[OrderedDict]: A generator of state dict shard of optimizer states. diff --git a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py index b5e40a817..e302805df 100644 --- a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py @@ -1,6 +1,6 @@ from typing import Optional -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from colossalai.zero.gemini.chunk import ChunkManager from .memory_stats import MemStats @@ -33,4 +33,4 @@ class ChunkMemStatsCollector(MemStatsCollector): def cuda_margin_mem(self) -> float: from colossalai.legacy.utils.memory import colo_device_memory_capacity - return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda + return colo_device_memory_capacity(get_accelerator().get_current_device()) - self._memstats.max_overall_cuda diff --git a/colossalai/zero/gemini/memory_tracer/memory_monitor.py b/colossalai/zero/gemini/memory_tracer/memory_monitor.py index 513a6326d..82c8e9dab 100644 --- a/colossalai/zero/gemini/memory_tracer/memory_monitor.py +++ b/colossalai/zero/gemini/memory_tracer/memory_monitor.py @@ -5,7 +5,7 @@ from time import sleep, time import torch -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator class MemoryMonitor: @@ -77,7 +77,7 @@ class AsyncMemoryMonitor(MemoryMonitor): super().__init__() self.keep_measuring = False - current_device = get_current_device() + current_device = get_accelerator().get_current_device() def _set_cuda_device(): torch.cuda.set_device(current_device) @@ -116,7 +116,7 @@ class AsyncMemoryMonitor(MemoryMonitor): while self.keep_measuring: max_usage = max( max_usage, - colo_device_memory_used(get_current_device()), + colo_device_memory_used(get_accelerator().get_current_device()), ) sleep(self.interval) return max_usage diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index c410ad379..388999549 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -6,8 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type import torch -from colossalai.utils import get_current_device -from colossalai.utils.memory import colo_device_memory_capacity +from colossalai.accelerator import get_accelerator +from colossalai.legacy.utils.memory import colo_device_memory_capacity from colossalai.zero.gemini.chunk import Chunk from .chunk import Chunk, ChunkManager @@ -85,7 +85,7 @@ class StaticPlacementPolicy(PlacementPolicy): # init offload optim settings # keep gathered chunks are in CUDA if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem: - device = get_current_device() + device = get_accelerator().get_current_device() else: device = torch.device("cpu") # real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here @@ -140,7 +140,7 @@ class AutoPlacementPolicy(PlacementPolicy): int: the volume of memory that is evicted """ start = time() - cuda_capacity = colo_device_memory_capacity(get_current_device()) + cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device()) used_cuda_model_data = self.chunk_manager.total_mem["cuda"] if warmup: # We designate a part of CUDA memory for model data in warmup iterations. @@ -194,7 +194,7 @@ class AutoPlacementPolicy(PlacementPolicy): # init offload optim settings # keep gathered chunks are in CUDA if chunk.keep_gathered: - grads_device_map[p] = get_current_device() + grads_device_map[p] = get_accelerator().get_current_device() else: grads_device_map[p] = torch.device("cpu") diff --git a/colossalai/zero/gemini/utils.py b/colossalai/zero/gemini/utils.py index 5305953fe..b563ea5b2 100644 --- a/colossalai/zero/gemini/utils.py +++ b/colossalai/zero/gemini/utils.py @@ -6,7 +6,7 @@ import torch import torch.distributed as dist import torch.nn as nn -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from .chunk import Chunk @@ -18,11 +18,11 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk, dtype: torch.dtype): if chunk.cuda_shard is not None: shard_temp = chunk.cuda_shard else: - shard_temp = chunk.cpu_shard.to(get_current_device()) + shard_temp = chunk.cpu_shard.to(get_accelerator().get_current_device()) shard_temp = shard_temp.to(dtype) - total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_current_device()) + total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_accelerator().get_current_device()) gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0)) dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg) diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index 2828d5175..f395fc60e 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -15,7 +15,7 @@ class BucketStore(BaseStore): # init self.current_group_id = 0 self._num_elements_in_bucket = 0 - # mapping gardient slices and parameter + # mapping gradient slices and parameter self.grad_to_param_mapping = dict() self._grad_in_bucket = dict() @@ -59,7 +59,7 @@ class BucketStore(BaseStore): self.offset_list[-1] += 1 def build_grad_in_bucket(self): - """Orgnize parameters' gradient(padding and split), follows the paramters' splitting method + """Organize parameters' gradient(padding and split), follows the parameters' splitting method Data structure of self._grad_in_bucket: { @@ -91,7 +91,7 @@ class BucketStore(BaseStore): return self._grad_in_bucket def get_flatten_grad(self) -> Tensor: - """Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor: + """Return the flattened gradients slices in the bucket, the data organization of the flattened tensor: [grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....] Returns: diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py index 1164532fa..73a1db5a0 100644 --- a/colossalai/zero/low_level/bookkeeping/gradient_store.py +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -9,7 +9,7 @@ class GradientStore(BaseStore): def __init__(self, *args, partition_grad: bool = False): super().__init__(*args) """ - self._grads_of_params mapping the paramater and its gradient slices + self._grads_of_params mapping the parameter and its gradient slices data structure: { group_id:{ diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index c1b35ee17..e01c852be 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -12,7 +12,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch.distributed import ProcessGroup from torch.optim import Optimizer -import colossalai.utils.device as device_utils +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.mixed_precision_mixin import ( BF16MixedPrecisionMixin, FP16MixedPrecisionMixin, @@ -22,9 +22,6 @@ from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger from colossalai.tensor.moe_tensor.api import is_moe_tensor -# from colossalai.tensor import ColoParameter, ProcessGroup -from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device - from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor from .bookkeeping import BucketStore, GradientStore, ParameterStore @@ -171,7 +168,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # managed by this data parallel rank param_group["params"] = master_param_current_rank - # if there are moe params, store in addtional group in optim + # if there are moe params, store in additional group in optim if len(moe_params) > 0: param_group = dict() for key, value in self.optim.param_groups[0].items(): @@ -180,10 +177,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper): param_group["params"] = moe_params self.optim.param_groups.append(param_group) - # intialize communication stream for - # communication-compuation overlapping + # initialize communication stream for + # communication-computation overlapping if self._overlap_communication: - self._comm_stream = device_utils.Stream() + self._comm_stream = get_accelerator().Stream() # reduction hook is only used if overlapping communication # or stage 2 is used @@ -217,7 +214,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): return len(self._working_param_groups) def _sanity_checks(self): - assert torch.cuda.is_available() or IS_NPU_AVAILABLE, "device is required" + assert get_accelerator().name in ["cuda", "npu"], "device is required" for param_group in self.optim.param_groups: group_params = param_group["params"] for param in group_params: @@ -228,7 +225,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def _create_master_param_current_rank(self, param_list): # split each param evenly by world size params_current_rank = [] - device = "cpu" if self._cpu_offload else get_current_device() + device = "cpu" if self._cpu_offload else get_accelerator().get_current_device() for param in param_list: padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size @@ -340,11 +337,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper): if len(moe_grad_list) > 0: moe_flat_grads.record_stream(stream) # waiting for ops in the default stream finishing - stream.wait_stream(device_utils.current_stream()) + stream.wait_stream(get_accelerator().current_stream()) else: - stream = device_utils.current_stream() + stream = get_accelerator().current_stream() - with device_utils.stream(stream): + with get_accelerator().stream(stream): group_id = self._bucket_store.current_group_id if self.moe_extra_dp_pg is None: @@ -486,7 +483,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # clear reduced grads if self._overlap_communication: - device_utils.synchronize() + get_accelerator().synchronize() self.zero_grad() @@ -505,7 +502,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # clear reduced grads if self._overlap_communication: - device_utils.synchronize() + get_accelerator().synchronize() self.zero_grad() @@ -621,7 +618,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): release_param_grad(self._master_param_groups_of_current_rank[group_id]) # update working partition updated by the current rank - device = get_current_device() + device = get_accelerator().get_current_device() for group_id in range(self.num_param_groups): master_working_param = self.optim.param_groups[group_id]["params"] for idx, splited_param in enumerate(master_working_param): @@ -661,7 +658,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper): norm_type = float(norm_type) if norm_type == inf: total_norm = max(grad.data.abs().max() for grad in gradients) - total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float + ) dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg) total_norm = total_norm_cuda.item() @@ -673,7 +672,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # Sum across all model parallel GPUs. total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float ) torch.distributed.all_reduce( total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg @@ -765,7 +764,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): Dict: the pytorch form state_dict """ zero_state = dict() - device = get_current_device() + device = get_accelerator().get_current_device() for param, state in self.optim.state.items(): zero_state[param] = copy.deepcopy(state) for k, v in state.items(): @@ -827,7 +826,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ret_block = dict() ret_block_size = 0 - device = get_current_device() + device = get_accelerator().get_current_device() local_states = self.optim.state_dict()["state"] for param_idx, states in local_states.items(): current_block_size = 0 diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index 7a0e3b1a0..e87eafb6e 100644 --- a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -45,7 +45,6 @@ from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device ``` ## Define Plugin Create a `HybridParallelPlugin` object and specify the desired parallelism strategies to be used. In this example, both pipeline parallelism and ZeRO-1 are used simultaneously. @@ -149,7 +148,7 @@ model, optimizer, _criterion, _, lr_scheduler = booster.boost( ## Training GPT-2 using hybrid parallelism -In the previous tutorial, We've explained how to inject various parallelism features into the model and its training components using the Booster and `HybridParallelPlugin`. Now we can start model training. +In the previous tutorial, We've explained how to inject various parallelism features into the model and its training components using the Booster and `HybridParallelPlugin`. Now we can start model training. Define a training function. When pipeline parallelism is used, you need to call `booster.execute_pipeline` to schedule the stages of model training. ```python def train_epoch( @@ -204,4 +203,4 @@ Training the gpt-2 model for epoch in range(NUM_EPOCHS): train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) ``` - \ No newline at end of file + diff --git a/docs/source/en/basics/booster_api.md b/docs/source/en/basics/booster_api.md index 4d7ffe5a4..2c75dd9ac 100644 --- a/docs/source/en/basics/booster_api.md +++ b/docs/source/en/basics/booster_api.md @@ -32,7 +32,7 @@ Plugin is an important component that manages parallel configuration (eg: The ge More details about usages of each plugin can be found in chapter [Booster Plugins](./booster_plugins.md). -Some plugins support lazy initialization, which can be used to save memory when initializating large models. For more details, please see [Lazy Initialization](../features/lazy_init.md). +Some plugins support lazy initialization, which can be used to save memory when initializing large models. For more details, please see [Lazy Initialization](../features/lazy_init.md). ### API of booster diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index 117406980..ae941b489 100644 --- a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -43,7 +43,6 @@ from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device ``` ### 定义plugin 定义一个[`HybridParallelPlugin`](../basics/booster_plugins.md)对象,指定所需要使用的并行策略,在该例子中,同时使用了流水线并行和zero1. @@ -201,4 +200,4 @@ def train_epoch( for epoch in range(NUM_EPOCHS): train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) ``` - \ No newline at end of file + diff --git a/examples/community/roberta/pretraining/run_pretraining.py b/examples/community/roberta/pretraining/run_pretraining.py index 5396de693..40b11d649 100644 --- a/examples/community/roberta/pretraining/run_pretraining.py +++ b/examples/community/roberta/pretraining/run_pretraining.py @@ -16,10 +16,10 @@ from utils.global_vars import get_tensorboard_writer, get_timers, set_global_var from utils.logger import Logger import colossalai +from colossalai.accelerator import get_accelerator from colossalai.context import ParallelMode from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.tensor import ProcessGroup, ShardSpec -from colossalai.utils import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext @@ -53,7 +53,7 @@ def main(): set_global_variables(launch_time, args.tensorboard_path) world_size = torch.distributed.get_world_size() - get_current_device() + get_accelerator().get_current_device() # build model, optimizer and criterion if args.distplan.startswith("CAI"): @@ -67,7 +67,10 @@ def main(): # build GPT model with ColoInitContext( - device=get_current_device(), dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg + device=get_accelerator().get_current_device(), + dtype=torch.half, + default_dist_spec=default_dist_spec, + default_pg=shard_pg, ): config, model, numel = get_model(args, logger) @@ -78,7 +81,7 @@ def main(): elif args.distplan == "CAI_Gemini": gemini_config = dict( strict_ddp_mode=args.tp_degree == 1, - device=get_current_device(), + device=get_accelerator().get_current_device(), placement_policy=args.placement, pin_memory=True, hidden_dim=model.config.hidden_size, diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index 1a7f8da7f..cc2b2ebc7 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -20,11 +20,11 @@ from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device disable_existing_loggers() logger = get_dist_logger() @@ -386,7 +386,7 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32 + torch_dtype = torch.float16 if get_accelerator().get_current_device() == "cuda" else torch.float32 pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -401,7 +401,7 @@ def main(args): sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) - pipeline.to(get_current_device()) + pipeline.to(get_accelerator().get_current_device()) for example in tqdm( sample_dataloader, @@ -578,8 +578,8 @@ def main(args): # Move text_encode and vae to gpu. # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. - vae.to(get_current_device(), dtype=weight_dtype) - text_encoder.to(get_current_device(), dtype=weight_dtype) + vae.to(get_accelerator().get_current_device(), dtype=weight_dtype) + text_encoder.to(get_accelerator().get_current_device(), dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader)) @@ -613,7 +613,7 @@ def main(args): torch.cuda.reset_peak_memory_stats() # Move batch to gpu for key, value in batch.items(): - batch[key] = value.to(get_current_device(), non_blocking=True) + batch[key] = value.to(get_accelerator().get_current_device(), non_blocking=True) # Convert images to latent space optimizer.zero_grad() diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py index ea6dde8bb..227488abe 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py @@ -21,13 +21,13 @@ from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device disable_existing_loggers() logger = get_dist_logger() @@ -385,7 +385,7 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32 + torch_dtype = torch.float16 if get_accelerator().get_current_device() == "cuda" else torch.float32 pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -400,7 +400,7 @@ def main(args): sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) - pipeline.to(get_current_device()) + pipeline.to(get_accelerator().get_current_device()) for example in tqdm( sample_dataloader, @@ -598,8 +598,8 @@ def main(args): # Move text_encode and vae to gpu. # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. - vae.to(get_current_device(), dtype=weight_dtype) - text_encoder.to(get_current_device(), dtype=weight_dtype) + vae.to(get_accelerator().get_current_device(), dtype=weight_dtype) + text_encoder.to(get_accelerator().get_current_device(), dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader)) @@ -633,7 +633,7 @@ def main(args): torch.cuda.reset_peak_memory_stats() # Move batch to gpu for key, value in batch.items(): - batch[key] = value.to(get_current_device(), non_blocking=True) + batch[key] = value.to(get_accelerator().get_current_device(), non_blocking=True) # Convert images to latent space optimizer.zero_grad() diff --git a/examples/images/resnet/train.py b/examples/images/resnet/train.py index 13df516d4..5871bbf87 100644 --- a/examples/images/resnet/train.py +++ b/examples/images/resnet/train.py @@ -13,12 +13,12 @@ from torch.utils.data import DataLoader from tqdm import tqdm import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin.dp_plugin_base import DPPluginBase from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -53,8 +53,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl @torch.no_grad() def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float: model.eval() - correct = torch.zeros(1, dtype=torch.int64, device=get_current_device()) - total = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) + total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) for images, labels in test_dataloader: images = images.cuda() labels = labels.cuda() diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py index b770bc9cf..078017324 100644 --- a/examples/images/vit/vit_benchmark.py +++ b/examples/images/vit/vit_benchmark.py @@ -33,9 +33,10 @@ def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224 def colo_memory_cap(size_in_GB): - from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + from colossalai.accelerator import get_accelerator + from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction - cuda_capacity = colo_device_memory_capacity(get_current_device()) + cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device()) if size_in_GB * (1024**3) < cuda_capacity: colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) print(f"Limiting GPU memory usage to {size_in_GB} GB") diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py index 772fe2200..c49d98982 100644 --- a/examples/inference/benchmark_llama.py +++ b/examples/inference/benchmark_llama.py @@ -8,11 +8,9 @@ import transformers from transformers import AutoTokenizer, GenerationConfig import colossalai -import colossalai.utils.device as device_utils -from colossalai.inference.config import InferenceConfig -from colossalai.inference.core.engine import InferenceEngine +from colossalai.accelerator import get_accelerator +from colossalai.inference import InferenceEngine from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn -from colossalai.utils.device import get_current_device GIGABYTE = 1024**3 MEGABYTE = 1024 * 1024 @@ -55,7 +53,7 @@ CONFIG_MAP = { def data_gen(batch_size: int = 4, seq_len: int = 512): - input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_current_device()) + input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_accelerator().get_current_device()) return input_ids @@ -78,9 +76,9 @@ def print_details_info(model_config, args, whole_end2end): msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n" if torch.cuda.is_available(): - msg += f"-------Memory Summary Device:{device_utils.current_device()}-------\n" - msg += f"Max memory allocated: {device_utils.max_memory_allocated() / GIGABYTE:.2f} GB\n" - msg += f"Max memory reserved: {device_utils.max_memory_reserved() / GIGABYTE:.2f} GB\n" + msg += f"-------Memory Summary Device:{get_accelerator().current_device()}-------\n" + msg += f"Max memory allocated: {get_accelerator().max_memory_allocated() / GIGABYTE:.2f} GB\n" + msg += f"Max memory reserved: {get_accelerator().max_memory_reserved() / GIGABYTE:.2f} GB\n" print(msg) diff --git a/examples/inference/run_llama_inference.py b/examples/inference/run_llama_inference.py index 8f85a9363..b5228c64e 100644 --- a/examples/inference/run_llama_inference.py +++ b/examples/inference/run_llama_inference.py @@ -5,9 +5,9 @@ import torch.distributed as dist from transformers import LlamaForCausalLM, LlamaTokenizer import colossalai +from colossalai.accelerator import get_accelerator from colossalai.inference import InferenceEngine from colossalai.testing import spawn -from colossalai.utils.device import get_current_device INPUT_TEXTS = [ "What is the longest river in the world?", @@ -57,7 +57,7 @@ def run_inference(args): ) inputs = tokenizer(INPUT_TEXTS, return_tensors="pt", padding="longest", max_length=max_input_len, truncation=True) - inputs = {k: v.to(get_current_device()) for k, v in inputs.items()} + inputs = {k: v.to(get_accelerator().get_current_device()) for k, v in inputs.items()} outputs = engine.generate(inputs) if rank == 0: diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index aad12c9c2..0b1e77fff 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -18,11 +18,11 @@ from transformers import ( ) import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -59,7 +59,7 @@ def evaluate_model( use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True) - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) for batch in dataloader: batch = move_to_cuda(batch) labels = batch["labels"] @@ -89,8 +89,10 @@ def evaluate_model( object_list = [None, None] dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group) - metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels) - accum_loss.add_(object_list[1].to(get_current_device())) + metric.add_batch( + predictions=object_list[0].to(get_accelerator().get_current_device()), references=labels + ) + accum_loss.add_(object_list[1].to(get_accelerator().get_current_device())) else: batch = move_to_cuda(batch) diff --git a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py index e811e1acb..b35112498 100644 --- a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py +++ b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py @@ -7,13 +7,13 @@ from model_zoo import GPTLMLoss, get_gpt2_components from torch.utils._pytree import tree_map import colossalai +from colossalai.accelerator import get_accelerator from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer from colossalai.auto_parallel.offload.mem_optimize import memory_optimize from colossalai.auto_parallel.offload.solver import NOT_NVML from colossalai.fx.profiler import parameter_size from colossalai.nn.optimizer import HybridAdam from colossalai.testing import spawn -from colossalai.utils import get_current_device def parse_args(): @@ -41,7 +41,7 @@ def train_gpt(args): 64, 8, ), - device=get_current_device(), + device=get_accelerator().get_current_device(), ) criterion = GPTLMLoss() diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 88b76c654..78d090ba2 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -12,12 +12,12 @@ from commons.utils import get_data, get_profile_context, get_tflops, get_time_st from packaging import version import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.lazy import LazyInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device CAI_VERSION = colossalai.__version__ @@ -141,7 +141,11 @@ def main(): criterion = GPTLMLoss() torch.manual_seed(123) if args.distplan.startswith("CAI"): - ctx = LazyInitContext(default_device=get_current_device()) if args.distplan == "CAI_Gemini" else nullcontext() + ctx = ( + LazyInitContext(default_device=get_accelerator().get_current_device()) + if args.distplan == "CAI_Gemini" + else nullcontext() + ) # build GPT model with ctx: model = model_builder(args.model_type)(checkpoint=True) diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py index 62804eff8..eb56ee530 100644 --- a/examples/language/gpt/hybridparallelism/finetune.py +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -13,11 +13,11 @@ from tqdm import tqdm from transformers import AutoConfig, GPT2ForSequenceClassification, get_linear_schedule_with_warmup import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -54,7 +54,7 @@ def evaluate_model( use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) for batch in dataloader: batch = move_to_cuda(batch) labels = batch["labels"] @@ -83,8 +83,10 @@ def evaluate_model( object_list = [None, None] dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group) - metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels) - accum_loss.add_(object_list[1].to(get_current_device())) + metric.add_batch( + predictions=object_list[0].to(get_accelerator().get_current_device()), references=labels + ) + accum_loss.add_(object_list[1].to(get_accelerator().get_current_device())) else: batch = move_to_cuda(batch) diff --git a/examples/language/gpt/titans/model/embed.py b/examples/language/gpt/titans/model/embed.py index b2e3f71a5..ec3df50c4 100644 --- a/examples/language/gpt/titans/model/embed.py +++ b/examples/language/gpt/titans/model/embed.py @@ -5,6 +5,7 @@ from torch import nn as nn from torch.nn import functional as F from torch.nn.parameter import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode, seed from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.base_layer import ParallelLayer @@ -12,7 +13,6 @@ from colossalai.legacy.nn.layer.parallel_1d._utils import gather_forward_split_b from colossalai.legacy.nn.layer.parallel_1d.layers import Linear1D_Row from colossalai.legacy.nn.layer.utils import divide from colossalai.legacy.registry import LAYERS, LOSSES -from colossalai.utils import get_current_device class VocabParallelEmbedding(torch.nn.Module): @@ -96,7 +96,9 @@ class VocabParallelEmbedding(torch.nn.Module): if position_ids is not None: position_ids = position_ids.view(-1, input_shape[-1]) if position_ids is None: - position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device()) + position_ids = torch.arange( + 0, input_shape[-1] + 0, dtype=torch.long, device=get_accelerator().get_current_device() + ) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) position_embeddings = self.position_embeddings(position_ids) @@ -194,7 +196,7 @@ class VocabParallelEmbedding1D(torch.nn.Module): self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index # Allocate weights and initialize. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs)) init.uniform_(self.weight, -1, 1) @@ -439,7 +441,9 @@ class HiddenParallelEmbedding(torch.nn.Module): if position_ids is not None: position_ids = position_ids.view(-1, input_shape[-1]) if position_ids is None: - position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device()) + position_ids = torch.arange( + 0, input_shape[-1] + 0, dtype=torch.long, device=get_accelerator().get_current_device() + ) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) position_embeddings = self.position_embeddings(position_ids) @@ -532,7 +536,7 @@ class HiddenParallelEmbedding1D(torch.nn.Module): self._weight = None # Allocate weights and initialize. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs)) init.uniform_(self.weight, -1, 1) diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index a4c29b7c8..b8f70ce9c 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -13,13 +13,12 @@ from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaForCausalLM import colossalai -import colossalai.utils.device as device_utils +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Constants @@ -74,8 +73,8 @@ def main(): parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini") parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") - parser.add_argument("--mbs", type=int, default=1) - parser.add_argument("--zero", type=int, default=0) + parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel") + parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") args = parser.parse_args() colossalai.launch_from_torch({}) @@ -98,7 +97,13 @@ def main(): extra_dp_size=args.extra_dp, ) elif args.plugin == "gemini_auto": - plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio, tp_size=args.tp, extra_dp_size=args.extra_dp) + plugin = GeminiPlugin( + placement_policy="auto", + precision="bf16", + warmup_non_model_data_ratio=args.warmup_ratio, + tp_size=args.tp, + extra_dp_size=args.extra_dp, + ) elif args.plugin == "fsdp": if use_empty_init: plugin = TorchFSDPPlugin( @@ -137,7 +142,7 @@ def main(): zero_stage=args.zero, num_model_chunks=2, enable_fused_normalization=torch.cuda.is_available(), - num_microbatches=args.mbs, + microbatch_size=args.mbs, precision="bf16", ) elif args.plugin == "3d_cpu": @@ -147,7 +152,7 @@ def main(): zero_stage=args.zero, cpu_offload=True, enable_fused_normalization=torch.cuda.is_available(), - num_microbatches=args.mbs, + microbatch_size=args.mbs, initial_scale=2**8, precision="bf16", ) @@ -171,7 +176,7 @@ def main(): # Initialize Model and Optimizer # ============================== init_ctx = ( - LazyInitContext(default_device=get_current_device()) + LazyInitContext(default_device=get_accelerator().get_current_device()) if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) else nullcontext() ) @@ -202,7 +207,9 @@ def main(): torch.set_default_dtype(torch.bfloat16) model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) torch.set_default_dtype(torch.float) - coordinator.print_on_master(f"Booster init max CUDA memory: {device_utils.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB" + ) coordinator.print_on_master( f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" ) @@ -228,7 +235,7 @@ def main(): performance_evaluator.on_step_end(**batch) performance_evaluator.on_fit_end() - coordinator.print_on_master(f"Max CUDA memory usage: {device_utils.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") if __name__ == "__main__": diff --git a/examples/language/llama2/data_utils.py b/examples/language/llama2/data_utils.py index a438833e1..6b9e8ef28 100644 --- a/examples/language/llama2/data_utils.py +++ b/examples/language/llama2/data_utils.py @@ -8,7 +8,7 @@ from torch.distributed import ProcessGroup from torch.distributed.distributed_c10d import _get_default_group from torch.utils.data import DataLoader, Dataset, DistributedSampler -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator class StatefulDistributedSampler(DistributedSampler): @@ -108,7 +108,9 @@ class RandomDataset(Dataset): def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): self.num_samples = num_samples self.max_length = max_length - self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.input_ids = torch.randint( + 0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device() + ) self.attention_mask = torch.ones_like(self.input_ids) def __len__(self): diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py index f7708b1a3..66b540076 100644 --- a/examples/language/llama2/finetune.py +++ b/examples/language/llama2/finetune.py @@ -21,13 +21,13 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM from transformers.models.llama.tokenization_llama import LlamaTokenizer import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device def get_model_numel(model: nn.Module) -> int: @@ -191,7 +191,9 @@ def main(): config = LlamaConfig.from_pretrained(args.model_path) # use lazy init when using GeminiPlugin init_ctx = ( - LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + LazyInitContext(default_device=get_accelerator().get_current_device()) + if isinstance(plugin, GeminiPlugin) + else nullcontext() ) with init_ctx: diff --git a/examples/language/llama2/performance_evaluator.py b/examples/language/llama2/performance_evaluator.py index 6b1c92711..c2169a730 100644 --- a/examples/language/llama2/performance_evaluator.py +++ b/examples/language/llama2/performance_evaluator.py @@ -5,9 +5,8 @@ import torch import torch.distributed as dist from torch import Tensor -import colossalai.utils.device as device_utils +from colossalai.accelerator import get_accelerator from colossalai.cluster import DistCoordinator -from colossalai.utils.device import get_current_device def divide(x: float, y: float) -> float: @@ -22,7 +21,7 @@ def divide(x: float, y: float) -> float: def all_reduce_mean(x: float, world_size: int) -> float: if world_size == 1: return x - tensor = torch.tensor([x], device=get_current_device()) + tensor = torch.tensor([x], device=get_accelerator().get_current_device()) dist.all_reduce(tensor) tensor = tensor / world_size return tensor.item() @@ -86,13 +85,13 @@ class PerformanceEvaluator: self.disable = self.ignore_steps > 0 and step < self.ignore_steps if self.disable: return - device_utils.synchronize() + get_accelerator().synchronize() self.timer.start() def on_step_end(self, input_ids: Tensor, **kwargs) -> None: if self.disable: return - device_utils.synchronize() + get_accelerator().synchronize() self.timer.end() batch_size, seq_len = input_ids.shape diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py index bb10f7a00..4cdf93e19 100644 --- a/examples/language/llama2/pretrain.py +++ b/examples/language/llama2/pretrain.py @@ -20,13 +20,13 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM from transformers.models.llama.tokenization_llama import LlamaTokenizer import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device MODEL_CONFIGS = { "7b": LlamaConfig(max_position_embeddings=4096), @@ -227,7 +227,9 @@ def main(): config = MODEL_CONFIGS[args.config] # use lazy init when using GeminiPlugin init_ctx = ( - LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + LazyInitContext(default_device=get_accelerator().get_current_device()) + if isinstance(plugin, GeminiPlugin) + else nullcontext() ) with init_ctx: @@ -273,11 +275,10 @@ def main(): dataloader.sampler.set_start_index(sampler_start_idx) for epoch in range(start_epoch, args.num_epochs): dataloader.sampler.set_epoch(epoch) - step_nums = num_steps_per_epoch - start_step dataloader_iter = iter(dataloader) with tqdm( - range(step_nums), + range(start_step, num_steps_per_epoch), desc=f"Epoch {epoch}", disable=not print_flag, total=num_steps_per_epoch, diff --git a/examples/language/llama2/scripts/benchmark_70B/3d.sh b/examples/language/llama2/scripts/benchmark_70B/3d.sh index d50c57042..cb8f218fa 100644 --- a/examples/language/llama2/scripts/benchmark_70B/3d.sh +++ b/examples/language/llama2/scripts/benchmark_70B/3d.sh @@ -14,4 +14,4 @@ cd ../.. export OMP_NUM_THREADS=8 -colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p 3d -g -x -b 8 --tp 4 --pp 2 --mbs 4 +colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p 3d -g -x -b 8 --tp 4 --pp 2 --mbs 1 diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index 65562b386..03b660ecf 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -14,6 +14,7 @@ from transformers.models.llama import LlamaConfig from utils import PerformanceEvaluator, get_model_numel import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator @@ -21,7 +22,6 @@ from colossalai.moe.layers import apply_load_balance from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import skip_init from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device def move_to_cuda(batch, device): @@ -64,13 +64,15 @@ class RandomDataset(Dataset): ) self.input_ids.append(encode["input_ids"]) self.attention_mask.append(encode["attention_mask"]) - self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device()) - self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device()) + self.input_ids = torch.cat(self.input_ids, dim=0).to(get_accelerator().get_current_device()) + self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_accelerator().get_current_device()) repeat_times = num_samples // self.input_ids.shape[0] + 1 self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples] self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples] else: - self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.input_ids = torch.randint( + 0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device() + ) self.attention_mask = torch.ones_like(self.input_ids) def __len__(self): diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index ec7644317..eee3b505a 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -35,7 +35,7 @@ from transformers.utils import ( replace_return_docstrings, ) -from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN +from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON from colossalai.moe.layers import SparseMLP from colossalai.moe.manager import MOE_MANAGER diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py index f354bbea9..17e7aa46c 100644 --- a/examples/language/openmoe/model/openmoe_policy.py +++ b/examples/language/openmoe/model/openmoe_policy.py @@ -43,7 +43,7 @@ class OpenMoePolicy(Policy): if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False raise NotImplementedError( - "openmoe dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.") diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index b08436166..1ae661f54 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -15,6 +15,7 @@ from transformers import T5Tokenizer from transformers.models.llama import LlamaConfig import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator @@ -22,7 +23,6 @@ from colossalai.moe.layers import apply_load_balance from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import skip_init from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device def move_to_cuda(batch, device): @@ -61,7 +61,9 @@ class RandomDataset(Dataset): def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None): self.num_samples = num_samples self.max_length = max_length - self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.input_ids = torch.randint( + 0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device() + ) self.attention_mask = torch.ones_like(self.input_ids) def __len__(self): diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py index 7af02e24e..4fac7b507 100644 --- a/examples/language/palm/train.py +++ b/examples/language/palm/train.py @@ -14,12 +14,12 @@ from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper from torch.utils.data import DataLoader, Dataset import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.lazy import LazyInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn import HybridAdam -from colossalai.utils import get_current_device # constants @@ -159,7 +159,11 @@ if args.distplan == "colossalai": logger.info(f"plugin: {plugin}") booster = Booster(plugin=plugin, **booster_kwargs) - ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == "gemini" else nullcontext() + ctx = ( + LazyInitContext(default_device=get_accelerator().get_current_device()) + if args.plugin == "gemini" + else nullcontext() + ) with ctx: model = PaLM(num_tokens=50304, dim=4096, depth=64) diff --git a/examples/tutorial/auto_parallel/README.md b/examples/tutorial/auto_parallel/README.md index 135615676..6f11298fc 100644 --- a/examples/tutorial/auto_parallel/README.md +++ b/examples/tutorial/auto_parallel/README.md @@ -49,7 +49,7 @@ You should expect to the log like this. This log shows the edge cost on the comp ### Auto-Checkpoint Tutorial -We prepare two bechmarks for you to test the performance of auto checkpoint +We prepare two benchmarks for you to test the performance of auto checkpoint The first test `auto_ckpt_solver_test.py` will show you the ability of solver to search checkpoint strategy that could fit in the given budget (test on GPT2 Medium and ResNet 50). It will output the benchmark summary and data visualization of peak memory vs. budget memory and relative step time vs. peak memory. diff --git a/examples/tutorial/new_api/cifar_resnet/train.py b/examples/tutorial/new_api/cifar_resnet/train.py index 4407a51c3..a4733126f 100644 --- a/examples/tutorial/new_api/cifar_resnet/train.py +++ b/examples/tutorial/new_api/cifar_resnet/train.py @@ -13,12 +13,12 @@ from torch.utils.data import DataLoader from tqdm import tqdm import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin.dp_plugin_base import DPPluginBase from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -53,8 +53,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl @torch.no_grad() def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float: model.eval() - correct = torch.zeros(1, dtype=torch.int64, device=get_current_device()) - total = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) + total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) for images, labels in test_dataloader: images = images.cuda() labels = labels.cuda() diff --git a/examples/tutorial/new_api/cifar_vit/train.py b/examples/tutorial/new_api/cifar_vit/train.py index 700e4d2e0..ec6c852b5 100644 --- a/examples/tutorial/new_api/cifar_vit/train.py +++ b/examples/tutorial/new_api/cifar_vit/train.py @@ -13,13 +13,13 @@ from torch.utils.data import DataLoader from tqdm import tqdm import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin.dp_plugin_base import DPPluginBase from colossalai.cluster import DistCoordinator from colossalai.nn.lr_scheduler import LinearWarmupLR from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -73,8 +73,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl @torch.no_grad() def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float: model.eval() - correct = torch.zeros(1, dtype=torch.int64, device=get_current_device()) - total = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) + total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) for images, labels in test_dataloader: images = images.cuda() labels = labels.cuda() diff --git a/examples/tutorial/new_api/glue_bert/finetune.py b/examples/tutorial/new_api/glue_bert/finetune.py index 990822c9f..e97c9017f 100644 --- a/examples/tutorial/new_api/glue_bert/finetune.py +++ b/examples/tutorial/new_api/glue_bert/finetune.py @@ -12,11 +12,11 @@ from tqdm import tqdm from transformers import AutoConfig, BertForSequenceClassification, get_linear_schedule_with_warmup import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -45,7 +45,7 @@ def evaluate( model.eval() def evaluate_subset(dataloader: DataLoader): - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) for batch in dataloader: batch = move_to_cuda(batch) outputs = model(**batch) diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py index 9bd23ffc8..3f0d04879 100755 --- a/examples/tutorial/opt/opt/run_clm.py +++ b/examples/tutorial/opt/opt/run_clm.py @@ -51,13 +51,13 @@ from transformers import ( from transformers.utils.versions import require_version import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.tensor import ProcessGroup from colossalai.legacy.utils import get_dataloader from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device from colossalai.zero import GeminiOptimizer require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") @@ -249,9 +249,9 @@ def parse_args(): def colo_memory_cap(size_in_GB): - from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction - cuda_capacity = colo_device_memory_capacity(get_current_device()) + cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device()) if size_in_GB * (1024**3) < cuda_capacity: colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) print("Using {} GB of GPU memory".format(size_in_GB)) @@ -265,7 +265,9 @@ class DummyDataloader: self.vocab_size = vocab_size def generate(self): - input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len), device=get_current_device()) + input_ids = torch.randint( + 0, self.vocab_size, (self.batch_size, self.seq_len), device=get_accelerator().get_current_device() + ) attention_mask = torch.ones_like(input_ids) return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids} @@ -390,7 +392,7 @@ def main(): if args.init_in_cpu: init_dev = torch.device("cpu") else: - init_dev = get_current_device() + init_dev = get_accelerator().get_current_device() cai_version = colossalai.__version__ logger.info(f"using Colossal-AI version {cai_version}") @@ -439,7 +441,9 @@ def main(): except ImportError: # this works for unreleased main branch, and this may be released on 0.2.9 from colossalai.zero import GeminiDDP - model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True) + model = GeminiDDP( + model, device=get_accelerator().get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True + ) elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): from colossalai.gemini import ChunkManager, GeminiManager diff --git a/examples/tutorial/sequence_parallel/model/bert.py b/examples/tutorial/sequence_parallel/model/bert.py index 7b0e93d95..64260374a 100644 --- a/examples/tutorial/sequence_parallel/model/bert.py +++ b/examples/tutorial/sequence_parallel/model/bert.py @@ -3,13 +3,13 @@ import inspect import torch import torch.nn as nn -from colossalai.kernel import LayerNorm from colossalai.legacy.context import ParallelMode from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper from colossalai.legacy.pipeline.utils import partition_uniform from colossalai.logging import get_dist_logger +from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm from .layers import BertDualHead, BertLayer, Embedding, PreProcessor, VocabEmbedding from .layers.init_method import init_normal, output_init_normal diff --git a/examples/tutorial/sequence_parallel/model/layers/head.py b/examples/tutorial/sequence_parallel/model/layers/head.py index 75afeee60..ff81ace39 100644 --- a/examples/tutorial/sequence_parallel/model/layers/head.py +++ b/examples/tutorial/sequence_parallel/model/layers/head.py @@ -3,9 +3,9 @@ import torch.nn as nn import torch.nn.functional as F from loss_func.cross_entropy import vocab_cross_entropy -from colossalai.kernel import LayerNorm from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc +from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm from .linear import Linear from .pooler import Pooler diff --git a/examples/tutorial/sequence_parallel/train.py b/examples/tutorial/sequence_parallel/train.py index e9ceb8d70..f25fc8189 100644 --- a/examples/tutorial/sequence_parallel/train.py +++ b/examples/tutorial/sequence_parallel/train.py @@ -8,12 +8,12 @@ from lr_scheduler import AnnealingLR from model.bert import BertForPretrain, build_pipeline_bert import colossalai -from colossalai.kernel import LayerNorm from colossalai.legacy.amp import AMP_TYPE from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.utils import is_using_pp from colossalai.logging import get_dist_logger +from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm from colossalai.nn.optimizer import FusedAdam from colossalai.utils import MultiTimer diff --git a/extensions/README.md b/extensions/README.md new file mode 100644 index 000000000..6f5feb55c --- /dev/null +++ b/extensions/README.md @@ -0,0 +1,140 @@ +# 🔌 Extensions + +## 📌 Table of Contents + +- [🔌 Extensions](#-extensions) + - [📌 Table of Contents](#-table-of-contents) + - [📚 Introduction](#-introduction) + - [🪅 Design](#-design) + - [🛠 API Usage](#-api-usage) + - [🏗 Write a customized extension](#-write-a-customized-extension) + - [✏️ Acknowledgement](#️-acknowledgement) + +## 📚 Introduction + +This module is a designed to offer extensions to the existing ColossalAI framework. It is designed to be a collection of high-performance kernels to speed up the training and inference process. Different from writing an individual kernel, the `extensions` module offers a layer of abstraction to collate kernels written in different compiler backends and for different hardware backends in an organized way. Please see the design and usage in the sections below. + +## 🪅 Design + +The `extensions` module is a sub-module of the `colossalai.kernel` module. This module is put at the project root directory so that it can be imported for AOT (ahead-of-time) build. At the same time, it is symbolically linked at the `colossalai.kernel.extensions` path for runtime build. + +As we want to support multi-backend kernels, we have to consider multiple compiler options such as `torch.jit`, `CUDA`, `triton` and multiple hardware backends such as `CPU`, `GPU` and `NPU`. To make it easy for the users, we have abstract away the kernels into extensions and expose a single loader to the user for each kind of kernel. + +For example, if the user wants to use the CPU Adam kernel, he can just call `load()` on the kernel loader. The kernel loader will automatically select the correct extension based on the current hardware and compiler backend. The user does not need to worry about the details of the kernel implementation. For example, if the user is using ARM CPU, then Arm kernel will be built and loaded. If it is a X86 CPU, then it is the X86 kernel that will be loaded. + +```python +from colossalai.kernel.kernel_loader import CPUAdamLoader + +# load the kernel compatible with the current hardware +kernel = CPUAdamLoader().load() +``` + +![](https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/extensions.png?raw=true) + +## 🛠 API Usage + +To make the `colossalai.kernel` easy to use, we expose some simple APIs and you can use them based on your scenario. + +- Case 1: Simply load a kernel + +```python +from colossalai.kernel.kernel_loader import CPUAdamLoader + +# load the kernel compatible with the current hardware +kernel = CPUAdamLoader().load() +``` + +- Case 2: Load a specific kernel + +This case applies if you are familiar with the extensions available. + +```python +from colossalai.kernel.kernel_loader import CPUAdamLoader + +# load the kernel by giving the kernel name +kernel = CPUAdamLoader().load(ext_name="cpu_adam_arm") +``` + +- Case 3: Register your own extension + +This case applies if you know how to write an extension. If you do not know how, you can refer to the section below. + +```python +from colossalai.kernel.kernel_loader import CPUAdamLoader +from colossalai.kernel.base_extension import _Extension + +# create your own extension class +class MyExtension(_Extension): + + def __init__(self): + self._name = "my_extension" + self._support_aot = True + self._support_jit = True + self.priority = 10 + + # implementation here + ... + +# register your extension +# you can use the priority value to make sure your kernel will be loaded by default +CPUAdamLoader.register_extension(MyExtension) + +# load the kernel +kernel = CPUAdamLoader().load() +``` + +## 🏗 Write a customized extension + +It is easy to write a customized extension. If you have experience writing CUDA/triton kernels, you should get familiar with the process quickly. + +You just need to inherit the `_Extension` base class or other backend-specific classes such as `_CudaExtension` and implement the abstract methods. Then, you need to register your extension to the kernel loader based on the Case 3 above. The kernel loader will automatically select the correct extension based on the priority score, current hardware, compiler backend. + +```python +from colossalai.kernel.base_extension import _Extension + + +class MyExtension(_Extension): + + def __init__(self): + self._name = "my_extension" + self._support_aot = True + self._support_jit = True + self.priority = 10 + + def is_hardware_available(self) -> bool: + """ + Return if the required hardware can be found. + """ + ... + + def assert_hardware_compatible(self) -> None: + """ + Check if the hardware required by the kernel is compatible. + """ + ... + + def build_aot(self) -> Union["CppExtension", "CUDAExtension"]: + """ + If this kernel can be built AOT, it should return an extension object + to Python setuptools for compilation. + """ + ... + + def build_jit(self) -> Callable: + """ + Build extension kernel just in time. + """ + ... + + def load(self): + """ + The API called by the user to get the kernel. + """ + ... + +``` + +## ✏️ Acknowledgement + +This module is written from scratch but we learnt a lot by looking into [DeepSpeed' +s op_builder](https://github.com/microsoft/DeepSpeed/tree/master/op_builder). We wish to acknowledge their great work and contributions to the open-source community. diff --git a/extensions/__init__.py b/extensions/__init__.py new file mode 100644 index 000000000..9343cadda --- /dev/null +++ b/extensions/__init__.py @@ -0,0 +1,36 @@ +from .cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension +from .flash_attention import ( + FlashAttentionDaoCudaExtension, + FlashAttentionNpuExtension, + FlashAttentionXformersCudaExtension, +) +from .layernorm import LayerNormCudaExtension +from .moe import MoeCudaExtension +from .optimizer import FusedOptimizerCudaExtension +from .softmax import ScaledMaskedSoftmaxCudaExtension, ScaledUpperTriangleMaskedSoftmaxCudaExtension + +ALL_EXTENSIONS = [ + CpuAdamArmExtension, + CpuAdamX86Extension, + LayerNormCudaExtension, + MoeCudaExtension, + FusedOptimizerCudaExtension, + ScaledMaskedSoftmaxCudaExtension, + ScaledUpperTriangleMaskedSoftmaxCudaExtension, + FlashAttentionDaoCudaExtension, + FlashAttentionXformersCudaExtension, + FlashAttentionNpuExtension, +] + +__all__ = [ + "CpuAdamArmExtension", + "CpuAdamX86Extension", + "LayerNormCudaExtension", + "MoeCudaExtension", + "FusedOptimizerCudaExtension", + "ScaledMaskedSoftmaxCudaExtension", + "ScaledUpperTriangleMaskedSoftmaxCudaExtension", + "FlashAttentionDaoCudaExtension", + "FlashAttentionXformersCudaExtension", + "FlashAttentionNpuExtension", +] diff --git a/extensions/base_extension.py b/extensions/base_extension.py new file mode 100644 index 000000000..c815a7f2a --- /dev/null +++ b/extensions/base_extension.py @@ -0,0 +1,82 @@ +import hashlib +import os +from abc import ABC, abstractmethod +from typing import Callable, Union + +__all__ = ["_Extension"] + + +class _Extension(ABC): + def __init__(self, name: str, support_aot: bool, support_jit: bool, priority: int = 1): + self._name = name + self._support_aot = support_aot + self._support_jit = support_jit + self.priority = priority + + @property + def name(self): + return self._name + + @property + def support_aot(self): + return self._support_aot + + @property + def support_jit(self): + return self._support_jit + + @staticmethod + def get_jit_extension_folder_path(): + """ + Kernels which are compiled during runtime will be stored in the same cache folder for reuse. + The folder is in the path ~/.cache/colossalai/torch_extensions/. + The name of the follows a common format: + torch._- + + The suffix is the hash value of the path of the `colossalai` file. + """ + import torch + + import colossalai + from colossalai.accelerator import get_accelerator + + # get torch version + torch_version_major = torch.__version__.split(".")[0] + torch_version_minor = torch.__version__.split(".")[1] + + # get device version + device_name = get_accelerator().name + device_version = get_accelerator().get_version() + + # use colossalai's file path as hash + hash_suffix = hashlib.sha256(colossalai.__file__.encode()).hexdigest() + + # concat + home_directory = os.path.expanduser("~") + extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_{device_name}-{device_version}-{hash_suffix}" + cache_directory = os.path.join(home_directory, extension_directory) + return cache_directory + + @abstractmethod + def is_hardware_available(self) -> bool: + """ + Check if the hardware required by the kernel is available. + """ + + @abstractmethod + def assert_hardware_compatible(self) -> None: + """ + Check if the hardware required by the kernel is compatible. + """ + + @abstractmethod + def build_aot(self) -> Union["CppExtension", "CUDAExtension"]: + pass + + @abstractmethod + def build_jit(self) -> Callable: + pass + + @abstractmethod + def load(self) -> Callable: + pass diff --git a/extensions/cpp_extension.py b/extensions/cpp_extension.py new file mode 100644 index 000000000..b4c40c9f1 --- /dev/null +++ b/extensions/cpp_extension.py @@ -0,0 +1,134 @@ +import importlib +import os +import time +from abc import abstractmethod +from pathlib import Path +from typing import List + +from .base_extension import _Extension + +__all__ = ["_CppExtension"] + + +class _CppExtension(_Extension): + def __init__(self, name: str, priority: int = 1): + super().__init__(name, support_aot=True, support_jit=True, priority=priority) + + # we store the op as an attribute to avoid repeated building and loading + self.cached_op = None + + # build-related variables + self.prebuilt_module_path = "colossalai._C" + self.prebuilt_import_path = f"{self.prebuilt_module_path}.{self.name}" + self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + + def csrc_abs_path(self, path): + return os.path.join(self.relative_to_abs_path("csrc"), path) + + def relative_to_abs_path(self, code_path: str) -> str: + """ + This function takes in a path relative to the colossalai root directory and return the absolute path. + """ + + # get the current file path + # iteratively check the parent directory + # if the parent directory is "extensions", then the current file path is the root directory + # otherwise, the current file path is inside the root directory + current_file_path = Path(__file__) + while True: + if current_file_path.name == "extensions": + break + else: + current_file_path = current_file_path.parent + extension_module_path = current_file_path + code_abs_path = extension_module_path.joinpath(code_path) + return str(code_abs_path) + + # functions must be overrided over + def strip_empty_entries(self, args): + """ + Drop any empty strings from the list of compile and link flags + """ + return [x for x in args if len(x) > 0] + + def import_op(self): + """ + This function will import the op module by its string name. + """ + return importlib.import_module(self.prebuilt_import_path) + + def build_aot(self) -> "CppExtension": + from torch.utils.cpp_extension import CppExtension + + return CppExtension( + name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args=self.strip_empty_entries(self.cxx_flags()), + ) + + def build_jit(self) -> None: + from torch.utils.cpp_extension import load + + build_directory = _Extension.get_jit_extension_folder_path() + build_directory = Path(build_directory) + build_directory.mkdir(parents=True, exist_ok=True) + + # check if the kernel has been built + compiled_before = False + kernel_file_path = build_directory.joinpath(f"{self.name}.o") + if kernel_file_path.exists(): + compiled_before = True + + # load the kernel + if compiled_before: + print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now") + else: + print(f"[extension] Compiling the JIT {self.name} kernel during runtime now") + + build_start = time.time() + op_kernel = load( + name=self.name, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_ldflags=[], + build_directory=str(build_directory), + ) + build_duration = time.time() - build_start + + if compiled_before: + print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds") + else: + print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds") + + return op_kernel + + # functions must be overrided begin + @abstractmethod + def sources_files(self) -> List[str]: + """ + This function should return a list of source files for extensions. + """ + + @abstractmethod + def include_dirs(self) -> List[str]: + """ + This function should return a list of include files for extensions. + """ + + @abstractmethod + def cxx_flags(self) -> List[str]: + """ + This function should return a list of cxx compilation flags for extensions. + """ + + def load(self): + try: + op_kernel = self.import_op() + except ImportError: + # if import error occurs, it means that the kernel is not pre-built + # so we build it jit + op_kernel = self.build_jit() + + return op_kernel diff --git a/extensions/cpu_adam/__init__.py b/extensions/cpu_adam/__init__.py new file mode 100644 index 000000000..cfd26a6a2 --- /dev/null +++ b/extensions/cpu_adam/__init__.py @@ -0,0 +1,5 @@ +from .cpu_adam_arm import CpuAdamArmExtension +from .cpu_adam_x86 import CpuAdamX86Extension + +__all__ = ['CpuAdamArmExtension', 'CpuAdamX86Extension'] + diff --git a/extensions/cpu_adam/cpu_adam_arm.py b/extensions/cpu_adam/cpu_adam_arm.py new file mode 100644 index 000000000..35bff3b55 --- /dev/null +++ b/extensions/cpu_adam/cpu_adam_arm.py @@ -0,0 +1,41 @@ +import platform + +from ..cpp_extension import _CppExtension + + +class CpuAdamArmExtension(_CppExtension): + def __init__(self): + super().__init__(name="cpu_adam_arm") + + def is_hardware_available(self) -> bool: + # only arm allowed + return platform.machine() == "aarch64" + + def assert_hardware_compatible(self) -> None: + arch = platform.machine() + assert ( + arch == "aarch64" + ), f"[extension] The {self.name} kernel requires the CPU architecture to be aarch64 but got {arch}" + + # necessary 4 functions + def sources_files(self): + ret = [ + self.csrc_abs_path("arm/cpu_adam_arm.cpp"), + ] + return ret + + def include_dirs(self): + return [] + + def cxx_flags(self): + extra_cxx_flags = [ + "-std=c++14", + "-std=c++17", + "-g", + "-Wno-reorder", + "-fopenmp", + ] + return ["-O3"] + self.version_dependent_macros + extra_cxx_flags + + def nvcc_flags(self): + return [] diff --git a/op_builder/cpu_adam.py b/extensions/cpu_adam/cpu_adam_x86.py similarity index 60% rename from op_builder/cpu_adam.py rename to extensions/cpu_adam/cpu_adam_x86.py index 7988aae4b..a38194167 100644 --- a/op_builder/cpu_adam.py +++ b/extensions/cpu_adam/cpu_adam_x86.py @@ -1,19 +1,27 @@ -from .builder import Builder -from .utils import append_nvcc_threads +import platform + +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads -class CPUAdamBuilder(Builder): - NAME = "cpu_adam" - PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam" - +class CpuAdamX86Extension(_CudaExtension): def __init__(self): - super().__init__(name=CPUAdamBuilder.NAME, prebuilt_import_path=CPUAdamBuilder.PREBUILT_IMPORT_PATH) - self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + super().__init__(name="cpu_adam_x86") + + def is_hardware_available(self) -> bool: + return platform.machine() == "x86_64" and super().is_hardware_available() + + def assert_hardware_compatible(self) -> None: + arch = platform.machine() + assert ( + arch == "x86_64" + ), f"[extension] The {self.name} kernel requires the CPU architecture to be x86_64 but got {arch}" + super().assert_hardware_compatible() # necessary 4 functions def sources_files(self): ret = [ - self.csrc_abs_path("cpu_adam.cpp"), + self.csrc_abs_path("cuda/cpu_adam.cpp"), ] return ret diff --git a/colossalai/kernel/cuda_native/__init__.py b/extensions/csrc/__init__.py similarity index 86% rename from colossalai/kernel/cuda_native/__init__.py rename to extensions/csrc/__init__.py index f8a974b5f..0eac28d23 100644 --- a/colossalai/kernel/cuda_native/__init__.py +++ b/extensions/csrc/__init__.py @@ -1,5 +1,4 @@ from .layer_norm import MixedFusedLayerNorm as LayerNorm -from .mha.mha import ColoAttention from .multihead_attention import MultiHeadAttention from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax @@ -8,6 +7,5 @@ __all__ = [ "MultiHeadAttention", "FusedScaleMaskSoftmax", "ScaledUpperTriangMaskedSoftmax", - "ColoAttention", "AttnMaskType", ] diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp b/extensions/csrc/arm/cpu_adam_arm.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp rename to extensions/csrc/arm/cpu_adam_arm.cpp diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h b/extensions/csrc/arm/cpu_adam_arm.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h rename to extensions/csrc/arm/cpu_adam_arm.h diff --git a/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp b/extensions/csrc/cuda/colossal_C_frontend.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp rename to extensions/csrc/cuda/colossal_C_frontend.cpp diff --git a/colossalai/kernel/cuda_native/csrc/compat.h b/extensions/csrc/cuda/compat.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/compat.h rename to extensions/csrc/cuda/compat.h diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp b/extensions/csrc/cuda/cpu_adam.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/cpu_adam.cpp rename to extensions/csrc/cuda/cpu_adam.cpp diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/extensions/csrc/cuda/cpu_adam.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/cpu_adam.h rename to extensions/csrc/cuda/cpu_adam.h diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h b/extensions/csrc/cuda/include/block_reduce.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h rename to extensions/csrc/cuda/include/block_reduce.h diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp b/extensions/csrc/cuda/layer_norm_cuda.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp rename to extensions/csrc/cuda/layer_norm_cuda.cpp diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu b/extensions/csrc/cuda/layer_norm_cuda_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu rename to extensions/csrc/cuda/layer_norm_cuda_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp b/extensions/csrc/cuda/moe_cuda.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/moe_cuda.cpp rename to extensions/csrc/cuda/moe_cuda.cpp diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu b/extensions/csrc/cuda/moe_cuda_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu rename to extensions/csrc/cuda/moe_cuda_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu b/extensions/csrc/cuda/multi_tensor_adam.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu rename to extensions/csrc/cuda/multi_tensor_adam.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh b/extensions/csrc/cuda/multi_tensor_apply.cuh similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh rename to extensions/csrc/cuda/multi_tensor_apply.cuh diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu b/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu rename to extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu b/extensions/csrc/cuda/multi_tensor_lamb.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu rename to extensions/csrc/cuda/multi_tensor_lamb.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu b/extensions/csrc/cuda/multi_tensor_scale_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu rename to extensions/csrc/cuda/multi_tensor_scale_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu b/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu rename to extensions/csrc/cuda/multi_tensor_sgd_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp b/extensions/csrc/cuda/scaled_masked_softmax.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp rename to extensions/csrc/cuda/scaled_masked_softmax.cpp diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h b/extensions/csrc/cuda/scaled_masked_softmax.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h rename to extensions/csrc/cuda/scaled_masked_softmax.h diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu b/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu rename to extensions/csrc/cuda/scaled_masked_softmax_cuda.cu diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp rename to extensions/csrc/cuda/scaled_upper_triang_masked_softmax.cpp diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h rename to extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu rename to extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu diff --git a/colossalai/kernel/cuda_native/csrc/type_shim.h b/extensions/csrc/cuda/type_shim.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/type_shim.h rename to extensions/csrc/cuda/type_shim.h diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/extensions/csrc/scaled_softmax.py similarity index 94% rename from colossalai/kernel/cuda_native/scaled_softmax.py rename to extensions/csrc/scaled_softmax.py index 26a5bce16..7c220d60d 100644 --- a/colossalai/kernel/cuda_native/scaled_softmax.py +++ b/extensions/csrc/scaled_softmax.py @@ -6,8 +6,7 @@ import enum import torch import torch.nn as nn -from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder -from colossalai.kernel.op_builder.scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder +from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader try: from colossalai._C import scaled_masked_softmax, scaled_upper_triang_masked_softmax @@ -35,7 +34,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): def forward(ctx, inputs, scale): global scaled_upper_triang_masked_softmax if scaled_upper_triang_masked_softmax: - scaled_upper_triang_masked_softmax = ScaledUpperTrainglemaskedSoftmaxBuilder().load() + scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load() scale_t = torch.tensor([scale]) softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0]) @@ -67,7 +66,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function): # build and load kernel if not pre-built global scaled_masked_softmax if scaled_masked_softmax is None: - scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load() + scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load() softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0]) ctx.save_for_backward(softmax_results, scale_t) diff --git a/extensions/cuda_extension.py b/extensions/cuda_extension.py new file mode 100644 index 000000000..b5e8a285b --- /dev/null +++ b/extensions/cuda_extension.py @@ -0,0 +1,106 @@ +import os +from abc import abstractmethod +from typing import List + +from .cpp_extension import _CppExtension +from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list + +__all__ = ["_CudaExtension"] + +# Some constants for installation checks +MIN_PYTORCH_VERSION_MAJOR = 1 +MIN_PYTORCH_VERSION_MINOR = 10 + + +class _CudaExtension(_CppExtension): + @abstractmethod + def nvcc_flags(self) -> List[str]: + """ + This function should return a list of nvcc compilation flags for extensions. + """ + + def is_hardware_available(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def assert_hardware_compatible(self) -> None: + from torch.utils.cpp_extension import CUDA_HOME + + if not CUDA_HOME: + raise AssertionError( + "[extension] CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build/load CUDA extensions" + ) + check_system_pytorch_cuda_match(CUDA_HOME) + check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR) + + def get_cuda_home_include(self): + """ + return include path inside the cuda home. + """ + from torch.utils.cpp_extension import CUDA_HOME + + if CUDA_HOME is None: + raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.") + cuda_include = os.path.join(CUDA_HOME, "include") + return cuda_include + + def build_jit(self) -> None: + from torch.utils.cpp_extension import CUDA_HOME, load + + set_cuda_arch_list(CUDA_HOME) + + # get build dir + build_directory = _Extension.get_jit_extension_folder_path() + build_directory = Path(build_directory) + build_directory.mkdir(parents=True, exist_ok=True) + + # check if the kernel has been built + compiled_before = False + kernel_file_path = build_directory.joinpath(f"{self.name}.o") + if kernel_file_path.exists(): + compiled_before = True + + # load the kernel + if compiled_before: + print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now") + else: + print(f"[extension] Compiling the JIT {self.name} kernel during runtime now") + + build_start = time.time() + op_kernel = load( + name=self.name, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_cuda_cflags=self.nvcc_flags(), + extra_ldflags=[], + build_directory=str(build_directory), + ) + build_duration = time.time() - build_start + + if compiled_before: + print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds") + else: + print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds") + + return op_kernel + + def build_aot(self) -> "CUDAExtension": + from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension + + set_cuda_arch_list(CUDA_HOME) + return CUDAExtension( + name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args={ + "cxx": self.strip_empty_entries(self.cxx_flags()), + "nvcc": self.strip_empty_entries(self.nvcc_flags()), + }, + ) diff --git a/extensions/flash_attention/__init__.py b/extensions/flash_attention/__init__.py new file mode 100644 index 000000000..18abb6191 --- /dev/null +++ b/extensions/flash_attention/__init__.py @@ -0,0 +1,20 @@ +from .flash_attention_dao_cuda import FlashAttentionDaoCudaExtension +from .flash_attention_npu import FlashAttentionNpuExtension +from .flash_attention_xformers_cuda import FlashAttentionXformersCudaExtension + +try: + import flash_attention # noqa + + HAS_FLASH_ATTN = True +except: + HAS_FLASH_ATTN = False + +try: + import xformers # noqa + + HAS_MEM_EFF_ATTN = True +except: + HAS_MEM_EFF_ATTN = False + + +__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionXformersCudaExtension", "FlashAttentionNpuExtension"] diff --git a/extensions/flash_attention/flash_attention_dao_cuda.py b/extensions/flash_attention/flash_attention_dao_cuda.py new file mode 100644 index 000000000..1b7f8ac47 --- /dev/null +++ b/extensions/flash_attention/flash_attention_dao_cuda.py @@ -0,0 +1,93 @@ +from ..base_extension import _Extension + + +class FlashAttentionDaoCudaExtension(_Extension): + def __init__(self): + super().__init__(name="flash_attention_dao_cuda", support_aot=False, support_jit=False, priority=10) + + def is_hardware_available(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def assert_hardware_compatible(self) -> bool: + pass + + def build_aot(self) -> None: + raise NotImplementedError( + "We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'." + ) + + def build_jit(self) -> None: + raise NotImplementedError( + "We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'" + ) + + def load(self): + try: + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + except ImportError: + raise ModuleNotFoundError( + ( + "We rely on the third-party flash-attn library for flash attention. Please install flash-attn via 'pip install flash-attn --no-build-isolation'" + ) + ) + + from typing import Optional + + import torch + + def flash_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q: "SeqLenInfo", + seq_len_info_kv: "SeqLenInfo", + origin_attn_mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: float = None, + causal: bool = False, + padded: bool = False, + ): + """ + Arguments: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + batch_size: int. + seq_len: int. + dropout_p: float. Dropout probability. + sm_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + attn_out: (batch, q_seqlen, nheads, headdim). + """ + # check if the input is in allowed dtypes + if padded: + if seq_len_info_kv == None: + seq_len_info_kv = seq_len_info_q + + attn_out = flash_attn_varlen_func( + q, + k, + v, + seq_len_info_q.cu_seqlens, + seq_len_info_kv.cu_seqlens, + seq_len_info_q.max_seqlen, + seq_len_info_kv.max_seqlen, + dropout_p, + scale, + causal, + ) + else: + attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal) + return attn_out + + return flash_attention diff --git a/extensions/flash_attention/flash_attention_npu.py b/extensions/flash_attention/flash_attention_npu.py new file mode 100644 index 000000000..58d0f9306 --- /dev/null +++ b/extensions/flash_attention/flash_attention_npu.py @@ -0,0 +1,73 @@ +from ..base_extension import _Extension + + +class FlashAttentionNpuExtension(_Extension): + def __init__(self): + super().__init__(name="flash_attention_npu", support_aot=False, support_jit=False) + + def is_hardware_available(self) -> bool: + try: + import torch_npu # noqa + + return True + except: + return False + + def assert_hardware_compatible(self) -> bool: + pass + + def build_aot(self) -> None: + raise NotImplementedError( + "Flash Attention NPU does not require ahead-of-time compilation. Please use it by installing torch_npu." + ) + + def build_jit(self) -> None: + raise NotImplementedError( + "Flash Attention NPU does not require just-in-time compilation. Please use it by installing torch_npu." + ) + + def load(self): + import torch + from einops import rearrange + + def npu_sdpa_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q=None, + seq_len_info_kv=None, + origin_attn_mask: torch.Tensor = None, + dropout_p: float = 0.0, + scale: float = 1.0, + causal=None, + padded=None, + ): + """ + The scaled dot product attention. + + Arguments: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + batch_size: int. + seq_len: int. + dropout_p: float. Dropout probability. + scale: float. The scaling of QK^T before applying softmax. + Default to 1. + Return: + attn_out: (batch, q_seqlen, nheads, headdim). + """ + q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)] + output = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=origin_attn_mask, + dropout_p=dropout_p, + is_causal=origin_attn_mask is None, + scale=scale, + ) + output = rearrange(output, "b h s d -> b s (h d)") + return output + + return npu_sdpa_attention diff --git a/extensions/flash_attention/flash_attention_xformers_cuda.py b/extensions/flash_attention/flash_attention_xformers_cuda.py new file mode 100644 index 000000000..27cd823de --- /dev/null +++ b/extensions/flash_attention/flash_attention_xformers_cuda.py @@ -0,0 +1,94 @@ +from ..base_extension import _Extension + + +class FlashAttentionXformersCudaExtension(_Extension): + def __init__(self): + super().__init__(name="flash_attention_xformers_cuda", support_aot=False, support_jit=False) + + def is_hardware_available(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def assert_hardware_compatible(self) -> bool: + pass + + def build_aot(self) -> None: + raise NotImplementedError( + "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." + ) + + def build_jit(self) -> None: + raise NotImplementedError( + "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." + ) + + def load(self): + try: + from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention + from xformers.ops.fmha.attn_bias import ( + BlockDiagonalCausalMask, + BlockDiagonalMask, + LowerTriangularMask, + LowerTriangularMaskWithTensorBias, + ) + except ImportError: + raise ModuleNotFoundError( + ( + "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." + ) + ) + from typing import Optional + + import torch + + allow_alibi = True + for op in MemoryEfficientAttentionCutlassOp: + allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) + + def mem_eff_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q: "SeqLenInfo", + seq_len_info_kv: "SeqLenInfo", + origin_attn_mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: float = None, + causal: bool = False, + padded: bool = False, + ): + attn_bias = None + if padded: # bert style + if not causal: + attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) + else: + attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) + elif causal: # gpt style + attn_bias = LowerTriangularMask() + + if bias is not None: # alibi / relative position embedding + assert allow_alibi, "flash attention with bias is not supported in this system." + assert causal, "attention with bias is only supported for causal attention so far." + attn_bias = attn_bias.add_bias(bias) + + if padded: + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + + out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale) + + # shape: (b*s, n, d) + if padded: + out = out.squeeze(0) + + return out + + return mem_eff_attention diff --git a/extensions/layernorm/__init__.py b/extensions/layernorm/__init__.py new file mode 100644 index 000000000..9d1bd2d01 --- /dev/null +++ b/extensions/layernorm/__init__.py @@ -0,0 +1,3 @@ +from .layernorm_cuda import LayerNormCudaExtension + +__all__ = ["LayerNormCudaExtension"] \ No newline at end of file diff --git a/extensions/layernorm/layernorm_cuda.py b/extensions/layernorm/layernorm_cuda.py new file mode 100644 index 000000000..db5f2fce1 --- /dev/null +++ b/extensions/layernorm/layernorm_cuda.py @@ -0,0 +1,24 @@ +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads, get_cuda_cc_flag + + +class LayerNormCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="layernorm_cuda") + + def sources_files(self): + ret = [self.csrc_abs_path(fname) for fname in ["cuda/layer_norm_cuda.cpp", "cuda/layer_norm_cuda_kernel.cu"]] + return ret + + def include_dirs(self): + ret = [self.get_cuda_home_include()] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ["-maxrregcount=50"] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + self.version_dependent_macros + return append_nvcc_threads(ret) diff --git a/extensions/moe/__init__.py b/extensions/moe/__init__.py new file mode 100644 index 000000000..962084d4b --- /dev/null +++ b/extensions/moe/__init__.py @@ -0,0 +1,3 @@ +from .moe_cuda import MoeCudaExtension + +__all__ = ['MoeCudaExtension'] \ No newline at end of file diff --git a/op_builder/moe.py b/extensions/moe/moe_cuda.py similarity index 56% rename from op_builder/moe.py rename to extensions/moe/moe_cuda.py index 6f8028b17..52883e97f 100644 --- a/op_builder/moe.py +++ b/extensions/moe/moe_cuda.py @@ -1,20 +1,17 @@ -from .builder import Builder -from .utils import append_nvcc_threads, get_cuda_cc_flag +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads, get_cuda_cc_flag -class MOEBuilder(Builder): - NAME = "moe" - PREBUILT_IMPORT_PATH = "colossalai._C.moe" - +class MoeCudaExtension(_CudaExtension): def __init__(self): - super().__init__(name=MOEBuilder.NAME, prebuilt_import_path=MOEBuilder.PREBUILT_IMPORT_PATH) + super().__init__(name="moe_cuda") def include_dirs(self): - ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] + ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()] return ret def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ["moe_cuda.cpp", "moe_cuda_kernel.cu"]] + ret = [self.csrc_abs_path(fname) for fname in ["cuda/moe_cuda.cpp", "cuda/moe_cuda_kernel.cu"]] return ret def cxx_flags(self): diff --git a/extensions/optimizer/__init__.py b/extensions/optimizer/__init__.py new file mode 100644 index 000000000..9c8e87cae --- /dev/null +++ b/extensions/optimizer/__init__.py @@ -0,0 +1,3 @@ +from .fused_optimizer_cuda import FusedOptimizerCudaExtension + +__all__ = ['FusedOptimizerCudaExtension'] \ No newline at end of file diff --git a/extensions/optimizer/fused_optimizer_cuda.py b/extensions/optimizer/fused_optimizer_cuda.py new file mode 100644 index 000000000..e065cf34a --- /dev/null +++ b/extensions/optimizer/fused_optimizer_cuda.py @@ -0,0 +1,34 @@ +from ..cuda_extension import _CudaExtension +from ..utils import get_cuda_cc_flag + + +class FusedOptimizerCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="fused_optim_cuda") + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "cuda/colossal_C_frontend.cpp", + "cuda/multi_tensor_sgd_kernel.cu", + "cuda/multi_tensor_scale_kernel.cu", + "cuda/multi_tensor_adam.cu", + "cuda/multi_tensor_l2norm_kernel.cu", + "cuda/multi_tensor_lamb.cu", + ] + ] + return ret + + def include_dirs(self): + ret = [self.get_cuda_home_include()] + return ret + + def cxx_flags(self): + version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + return ["-O3"] + version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ["-lineinfo"] + extra_cuda_flags.extend(get_cuda_cc_flag()) + return ["-O3", "--use_fast_math"] + extra_cuda_flags diff --git a/extensions/softmax/__init__.py b/extensions/softmax/__init__.py new file mode 100644 index 000000000..9bc50c6cd --- /dev/null +++ b/extensions/softmax/__init__.py @@ -0,0 +1,4 @@ +from .scaled_masked_softmax_cuda import ScaledMaskedSoftmaxCudaExtension +from .scaled_upper_triangle_masked_softmax_cuda import ScaledUpperTriangleMaskedSoftmaxCudaExtension + +__all__ = ['ScaledMaskedSoftmaxCudaExtension', 'ScaledUpperTriangleMaskedSoftmaxCudaExtension'] \ No newline at end of file diff --git a/op_builder/scaled_masked_softmax.py b/extensions/softmax/scaled_masked_softmax_cuda.py similarity index 50% rename from op_builder/scaled_masked_softmax.py rename to extensions/softmax/scaled_masked_softmax_cuda.py index d9239a80e..5b4208dba 100644 --- a/op_builder/scaled_masked_softmax.py +++ b/extensions/softmax/scaled_masked_softmax_cuda.py @@ -1,23 +1,20 @@ -from .builder import Builder -from .utils import append_nvcc_threads +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads -class ScaledMaskedSoftmaxBuilder(Builder): - NAME = "scaled_masked_softmax" - PREBUILT_IMPORT_PATH = "colossalai._C.scaled_masked_softmax" - +class ScaledMaskedSoftmaxCudaExtension(_CudaExtension): def __init__(self): - super().__init__( - name=ScaledMaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledMaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH - ) + super().__init__(name="scaled_masked_softmax_cuda") - # necessary 4 functions def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ["scaled_masked_softmax.cpp", "scaled_masked_softmax_cuda.cu"]] + ret = [ + self.csrc_abs_path(fname) + for fname in ["cuda/scaled_masked_softmax.cpp", "cuda/scaled_masked_softmax_cuda.cu"] + ] return ret def include_dirs(self): - return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] + return [self.get_cuda_home_include()] def cxx_flags(self): return ["-O3"] + self.version_dependent_macros diff --git a/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py b/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py new file mode 100644 index 000000000..d4f27a921 --- /dev/null +++ b/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py @@ -0,0 +1,34 @@ +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads, get_cuda_cc_flag + + +class ScaledUpperTriangleMaskedSoftmaxCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="scaled_upper_triangle_masked_softmax_cuda") + + def include_dirs(self): + return [self.get_cuda_home_include()] + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "cuda/scaled_upper_triang_masked_softmax.cpp", + "cuda/scaled_upper_triang_masked_softmax_cuda.cu", + ] + ] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + ] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/extensions/triton_extension.py b/extensions/triton_extension.py new file mode 100644 index 000000000..9f0792f8c --- /dev/null +++ b/extensions/triton_extension.py @@ -0,0 +1,21 @@ +from .base_extension import _Extension + +__all__ = ["_TritonExtension"] + + +class _TritonExtension(_Extension): + def __init__(self, name: str, priority: int = 1): + super().__init__(name, support_aot=False, support_jit=True, priority=priority) + + def is_hardware_compatible(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def load(self): + return self.build_jit() diff --git a/op_builder/utils.py b/extensions/utils.py similarity index 100% rename from op_builder/utils.py rename to extensions/utils.py diff --git a/op_builder/README.md b/op_builder/README.md deleted file mode 100644 index 9c33a4a32..000000000 --- a/op_builder/README.md +++ /dev/null @@ -1,32 +0,0 @@ -# Build PyTorch Extensions - -## Overview - -Building PyTorch extensions can be a difficult task for users not from the system background. It is definitely frustrating if the users encounter many strange technical jargons when install Colossal-AI. Therefore, we will provide two methods of building the PyTorch extensions for the users. - -1. Build CUDA extensions when running `pip install` if `CUDA_EXT=1` -2. Build the extension during runtime - -The first method is more suitable for users who are familiar with CUDA environment configurations. The second method is for those who are not as they only need to build the kernel which is required by their program. - -These two methods have different advantages and disadvantages. -Method 1 is good because it allows the user to build all kernels during installation and directly import the kernel. They don't need to care about kernel building when running their program. However, installation may fail if they don't know how to configure their environments and this leads to much frustration. -Method 2 is good because it allows the user to only build the kernel they actually need, such that there is a lower probability that they encounter environment issue. However, it may slow down their program due to the first build and subsequence load. - -## PyTorch Extensions in Colossal-AI - -The project [DeepSpeed](https://github.com/microsoft/DeepSpeed) has proposed a [solution](https://github.com/microsoft/DeepSpeed/tree/master/op_builder) to support kernel-build during either installation or runtime. -We have adapted from DeepSpeed's solution to build extensions. The extension build requires two main functions from PyTorch: - -1. `torch.utils.cpp_extension.CUDAExtension`: used to build extensions in `setup.py` during `pip install`. -2. `torch.utils.cpp_extension.load`: used to build and load extension during runtime - -Please note that the extension build by `CUDAExtension` cannot be loaded by the `load` function and `load` will run its own build again (correct me if I am wrong). - -Based on the DeepSpeed's work, we have make several modifications and improvements: - -1. All pre-built kernels (those installed with `setup.py`) will be found in `colossalai._C` -2. All runtime-built kernels will be found in the default torch extension path, i.e. ~/.cache/colossalai/torch_extensions. (If we put the built kernels in the installed site-package directory, this will make pip uninstall incomplete) -3. Once a kernel is loaded, we will cache it in the builder to avoid repeated kernel loading. - -When loading the built kernel, we will first check if the pre-built one exists. If not, the runtime build will be triggered. diff --git a/op_builder/__init__.py b/op_builder/__init__.py deleted file mode 100644 index 21e216437..000000000 --- a/op_builder/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -from .arm_cpu_adam import ArmCPUAdamBuilder -from .cpu_adam import CPUAdamBuilder -from .fused_optim import FusedOptimBuilder -from .layernorm import LayerNormBuilder -from .moe import MOEBuilder -from .multi_head_attn import MultiHeadAttnBuilder -from .scaled_masked_softmax import ScaledMaskedSoftmaxBuilder -from .scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder - -ALL_OPS = { - "cpu_adam": CPUAdamBuilder, - "fused_optim": FusedOptimBuilder, - "moe": MOEBuilder, - "multi_head_attn": MultiHeadAttnBuilder, - "scaled_masked_softmax": ScaledMaskedSoftmaxBuilder, - "scaled_upper_triangle_masked_softmax": ScaledUpperTrainglemaskedSoftmaxBuilder, - "layernorm": LayerNormBuilder, -} - -__all__ = [ - "ALL_OPS", - "CPUAdamBuilder", - "FusedOptimBuilder", - "MultiHeadAttnBuilder", - "ScaledMaskedSoftmaxBuilder", - "ScaledUpperTrainglemaskedSoftmaxBuilder", - "MOEBuilder", - "MultiTensorSGDBuilder", - "MultiTensorAdamBuilder", - "MultiTensorLambBuilder", - "MultiTensorScaleBuilder", - "MultiTensorL2NormBuilder", - "ArmCPUAdamBuilder", -] diff --git a/op_builder/arm_cpu_adam.py b/op_builder/arm_cpu_adam.py deleted file mode 100644 index 18dd519fa..000000000 --- a/op_builder/arm_cpu_adam.py +++ /dev/null @@ -1,34 +0,0 @@ -from .builder import Builder - - -class ArmCPUAdamBuilder(Builder): - NAME = "arm_cpu_adam" - PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam" - ext_type = "cpu" - - def __init__(self): - super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH) - self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] - - # necessary 4 functions - def sources_files(self): - ret = [ - self.csrc_abs_path("cpu_adam_arm.cpp"), - ] - return ret - - def include_dirs(self): - return [self.csrc_abs_path("includes")] - - def cxx_flags(self): - extra_cxx_flags = [ - "-std=c++14", - "-std=c++17", - "-g", - "-Wno-reorder", - "-fopenmp", - ] - return ["-O3"] + self.version_dependent_macros + extra_cxx_flags - - def nvcc_flags(self): - return [] diff --git a/op_builder/builder.py b/op_builder/builder.py deleted file mode 100644 index d804cb160..000000000 --- a/op_builder/builder.py +++ /dev/null @@ -1,236 +0,0 @@ -# This code has been adapted from the DeepSpeed library. -# Copyright (c) Microsoft Corporation. - -# Licensed under the MIT License. -import importlib -import os -import time -from abc import ABC, abstractmethod -from pathlib import Path -from typing import List, Optional, Union - -from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0 - - -class Builder(ABC): - """ - Builder is the base class to build extensions for PyTorch. - - Args: - name (str): the name of the kernel to be built - prebuilt_import_path (str): the path where the extension is installed during pip install - """ - - ext_type: str = "cuda" - - def __init__(self, name: str, prebuilt_import_path: str): - self.name = name - self.prebuilt_import_path = prebuilt_import_path - self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] - - # we store the op as an attribute to avoid repeated building and loading - self.cached_op_module = None - - assert prebuilt_import_path.startswith( - "colossalai._C" - ), f"The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}" - - def relative_to_abs_path(self, code_path: str) -> str: - """ - This function takes in a path relative to the colossalai root directory and return the absolute path. - """ - op_builder_module_path = Path(__file__).parent - - # if we install from source - # the current file path will be op_builder/builder.py - # if we install via pip install colossalai - # the current file path will be colossalai/kernel/op_builder/builder.py - # this is because that the op_builder inside colossalai is a symlink - # this symlink will be replaced with actual files if we install via pypi - # thus we cannot tell the colossalai root directory by checking whether the op_builder - # is a symlink, we can only tell whether it is inside or outside colossalai - if str(op_builder_module_path).endswith("colossalai/kernel/op_builder"): - root_path = op_builder_module_path.parent.parent - else: - root_path = op_builder_module_path.parent.joinpath("colossalai") - - code_abs_path = root_path.joinpath(code_path) - return str(code_abs_path) - - def get_cuda_home_include(self): - """ - return include path inside the cuda home. - """ - from torch.utils.cpp_extension import CUDA_HOME - - if CUDA_HOME is None: - raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.") - cuda_include = os.path.join(CUDA_HOME, "include") - return cuda_include - - def csrc_abs_path(self, path): - return os.path.join(self.relative_to_abs_path("kernel/cuda_native/csrc"), path) - - # functions must be overrided begin - @abstractmethod - def sources_files(self) -> List[str]: - """ - This function should return a list of source files for extensions. - """ - raise NotImplementedError - - @abstractmethod - def include_dirs(self) -> List[str]: - """ - This function should return a list of include files for extensions. - """ - - @abstractmethod - def cxx_flags(self) -> List[str]: - """ - This function should return a list of cxx compilation flags for extensions. - """ - - @abstractmethod - def nvcc_flags(self) -> List[str]: - """ - This function should return a list of nvcc compilation flags for extensions. - """ - - # functions must be overrided over - def strip_empty_entries(self, args): - """ - Drop any empty strings from the list of compile and link flags - """ - return [x for x in args if len(x) > 0] - - def import_op(self): - """ - This function will import the op module by its string name. - """ - return importlib.import_module(self.prebuilt_import_path) - - def check_runtime_build_environment(self): - """ - Check whether the system environment is ready for extension compilation. - """ - try: - from torch.utils.cpp_extension import CUDA_HOME - - TORCH_AVAILABLE = True - except ImportError: - TORCH_AVAILABLE = False - CUDA_HOME = None - - if not TORCH_AVAILABLE: - raise ModuleNotFoundError( - "PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions" - ) - - if CUDA_HOME is None: - raise RuntimeError( - "CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build CUDA extensions" - ) - - # make sure CUDA is available for compilation during - cuda_available = check_cuda_availability() - if not cuda_available: - raise RuntimeError("CUDA is not available on your system as torch.cuda.is_available() returns False.") - - # make sure system CUDA and pytorch CUDA match, an error will raised inside the function if not - check_system_pytorch_cuda_match(CUDA_HOME) - - def load(self, verbose: Optional[bool] = None): - """ - load the kernel during runtime. If the kernel is not built during pip install, it will build the kernel. - If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the - kernel is built during pip install, it can be accessed through `colossalai._C`. - - Warning: do not load this kernel repeatedly during model execution as it could slow down the training process. - - Args: - verbose (bool, optional): show detailed info. Defaults to True. - """ - if verbose is None: - verbose = os.environ.get("CAI_KERNEL_VERBOSE", "0") == "1" - # if the kernel has be compiled and cached, we directly use it - if self.cached_op_module is not None: - return self.cached_op_module - - try: - # if the kernel has been pre-built during installation - # we just directly import it - op_module = self.import_op() - if verbose: - print_rank_0( - f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building." - ) - except ImportError: - # check environment - if self.ext_type == "cuda": - self.check_runtime_build_environment() - - # time the kernel compilation - start_build = time.time() - - # construct the build directory - import torch - from torch.utils.cpp_extension import load - - torch_version_major = torch.__version__.split(".")[0] - torch_version_minor = torch.__version__.split(".")[1] - torch_cuda_version = torch.version.cuda - home_directory = os.path.expanduser("~") - extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}" - build_directory = os.path.join(home_directory, extension_directory) - Path(build_directory).mkdir(parents=True, exist_ok=True) - - if verbose: - print_rank_0(f"[extension] Compiling or loading the JIT-built {self.name} kernel during runtime now") - - # load the kernel - op_module = load( - name=self.name, - sources=self.strip_empty_entries(self.sources_files()), - extra_include_paths=self.strip_empty_entries(self.include_dirs()), - extra_cflags=self.cxx_flags(), - extra_cuda_cflags=self.nvcc_flags(), - extra_ldflags=[], - build_directory=build_directory, - verbose=verbose, - ) - - build_duration = time.time() - start_build - - # log jit compilation time - if verbose: - print_rank_0(f"[extension] Time to compile or load {self.name} op: {build_duration} seconds") - - # cache the built/loaded kernel - self.cached_op_module = op_module - - return op_module - - def builder(self) -> Union["CUDAExtension", "CppExtension"]: - """ - get a CUDAExtension instance used for setup.py - """ - from torch.utils.cpp_extension import CppExtension, CUDAExtension - - if self.ext_type == "cpp": - return CppExtension( - name=self.prebuilt_import_path, - sources=self.strip_empty_entries(self.sources_files()), - include_dirs=self.strip_empty_entries(self.include_dirs()), - extra_compile_args=self.strip_empty_entries(self.cxx_flags()), - ) - - return CUDAExtension( - name=self.prebuilt_import_path, - sources=self.strip_empty_entries(self.sources_files()), - include_dirs=self.strip_empty_entries(self.include_dirs()), - extra_compile_args={ - "cxx": self.strip_empty_entries(self.cxx_flags()), - "nvcc": self.strip_empty_entries(self.nvcc_flags()), - }, - ) diff --git a/op_builder/fused_optim.py b/op_builder/fused_optim.py deleted file mode 100644 index 3baa0880d..000000000 --- a/op_builder/fused_optim.py +++ /dev/null @@ -1,37 +0,0 @@ -from .builder import Builder -from .utils import get_cuda_cc_flag - - -class FusedOptimBuilder(Builder): - NAME = "fused_optim" - PREBUILT_IMPORT_PATH = "colossalai._C.fused_optim" - - def __init__(self): - super().__init__(name=FusedOptimBuilder.NAME, prebuilt_import_path=FusedOptimBuilder.PREBUILT_IMPORT_PATH) - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in [ - "colossal_C_frontend.cpp", - "multi_tensor_sgd_kernel.cu", - "multi_tensor_scale_kernel.cu", - "multi_tensor_adam.cu", - "multi_tensor_l2norm_kernel.cu", - "multi_tensor_lamb.cu", - ] - ] - return ret - - def include_dirs(self): - ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] - return ret - - def cxx_flags(self): - version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] - return ["-O3"] + version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = ["-lineinfo"] - extra_cuda_flags.extend(get_cuda_cc_flag()) - return ["-O3", "--use_fast_math"] + extra_cuda_flags diff --git a/op_builder/gptq.py b/op_builder/gptq.py deleted file mode 100644 index a17801f87..000000000 --- a/op_builder/gptq.py +++ /dev/null @@ -1,56 +0,0 @@ -import re - -import torch - -from .builder import Builder -from .utils import append_nvcc_threads - - -class GPTQBuilder(Builder): - NAME = "cu_gptq" - PREBUILT_IMPORT_PATH = "colossalai._C.cu_gptq" - - def __init__(self): - super().__init__(name=GPTQBuilder.NAME, prebuilt_import_path=GPTQBuilder.PREBUILT_IMPORT_PATH) - - def include_dirs(self): - ret = [self.csrc_abs_path("gptq"), self.get_cuda_home_include()] - return ret - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in [ - "gptq/linear_gptq.cpp", - "gptq/column_remap.cu", - "gptq/cuda_buffers.cu", - "gptq/q4_matmul.cu", - "gptq/q4_matrix.cu", - ] - ] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = [ - "-v", - "-std=c++14", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-DTHRUST_IGNORE_CUB_VERSION_CHECK", - "-lcublas", - ] - - for arch in torch.cuda.get_arch_list(): - res = re.search(r"sm_(\d+)", arch) - if res: - arch_cap = res[1] - if int(arch_cap) >= 80: - extra_cuda_flags.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"]) - - ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags - return append_nvcc_threads(ret) diff --git a/op_builder/layernorm.py b/op_builder/layernorm.py deleted file mode 100644 index 2684c6ddb..000000000 --- a/op_builder/layernorm.py +++ /dev/null @@ -1,27 +0,0 @@ -from .builder import Builder -from .utils import append_nvcc_threads, get_cuda_cc_flag - - -class LayerNormBuilder(Builder): - NAME = "layernorm" - PREBUILT_IMPORT_PATH = "colossalai._C.layernorm" - - def __init__(self): - super().__init__(name=LayerNormBuilder.NAME, prebuilt_import_path=LayerNormBuilder.PREBUILT_IMPORT_PATH) - - def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ["layer_norm_cuda.cpp", "layer_norm_cuda_kernel.cu"]] - return ret - - def include_dirs(self): - ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = ["-maxrregcount=50"] - extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + self.version_dependent_macros - return append_nvcc_threads(ret) diff --git a/op_builder/multi_head_attn.py b/op_builder/multi_head_attn.py deleted file mode 100644 index cb8fc489c..000000000 --- a/op_builder/multi_head_attn.py +++ /dev/null @@ -1,46 +0,0 @@ -from .builder import Builder -from .utils import append_nvcc_threads, get_cuda_cc_flag - - -class MultiHeadAttnBuilder(Builder): - NAME = "multihead_attention" - PREBUILT_IMPORT_PATH = "colossalai._C.multihead_attention" - - def __init__(self): - super().__init__(name=MultiHeadAttnBuilder.NAME, prebuilt_import_path=MultiHeadAttnBuilder.PREBUILT_IMPORT_PATH) - - def include_dirs(self): - ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] - return ret - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in [ - "multihead_attention_1d.cpp", - "kernels/cublas_wrappers.cu", - "kernels/transform_kernels.cu", - "kernels/dropout_kernels.cu", - "kernels/normalize_kernels.cu", - "kernels/softmax_kernels.cu", - "kernels/general_kernels.cu", - "kernels/cuda_util.cu", - ] - ] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = [ - "-std=c++14", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-DTHRUST_IGNORE_CUB_VERSION_CHECK", - ] - extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags - return append_nvcc_threads(ret) diff --git a/op_builder/scaled_upper_triangle_masked_softmax.py b/op_builder/scaled_upper_triangle_masked_softmax.py deleted file mode 100644 index 1445230ac..000000000 --- a/op_builder/scaled_upper_triangle_masked_softmax.py +++ /dev/null @@ -1,37 +0,0 @@ -from .builder import Builder -from .utils import append_nvcc_threads, get_cuda_cc_flag - - -class ScaledUpperTrainglemaskedSoftmaxBuilder(Builder): - NAME = "scaled_upper_triangle_masked_softmax" - PREBUILT_IMPORT_PATH = "colossalai._C.scaled_upper_triangle_masked_softmax" - - def __init__(self): - super().__init__( - name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME, - prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH, - ) - - def include_dirs(self): - return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in ["scaled_upper_triang_masked_softmax.cpp", "scaled_upper_triang_masked_softmax_cuda.cu"] - ] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = [ - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - ] - extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ["-O3", "--use_fast_math"] + extra_cuda_flags - return append_nvcc_threads(ret) diff --git a/op_builder/smoothquant.py b/op_builder/smoothquant.py deleted file mode 100644 index d562a4c4f..000000000 --- a/op_builder/smoothquant.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch - -from .builder import Builder -from .utils import append_nvcc_threads - - -class SmoothquantBuilder(Builder): - NAME = "cu_smoothquant" - PREBUILT_IMPORT_PATH = "colossalai._C.cu_smoothquant" - - def __init__(self): - super().__init__(name=SmoothquantBuilder.NAME, prebuilt_import_path=SmoothquantBuilder.PREBUILT_IMPORT_PATH) - - def include_dirs(self): - ret = [self.csrc_abs_path("smoothquant"), self.get_cuda_home_include()] - return ret - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in [ - "smoothquant/binding.cpp", - "smoothquant/linear.cu", - ] - ] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - compute_capability = torch.cuda.get_device_capability() - cuda_arch = compute_capability[0] * 100 + compute_capability[1] * 10 - - extra_cuda_flags = [ - "-v", - f"-DCUDA_ARCH={cuda_arch}", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-DTHRUST_IGNORE_CUB_VERSION_CHECK", - ] - - ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags - return append_nvcc_threads(ret) - - def builder(self): - try: - super().builder() - except: - warnings.warn("build smoothquant lib not successful") diff --git a/setup.py b/setup.py index cda1ba7ee..1244bfff0 100644 --- a/setup.py +++ b/setup.py @@ -5,55 +5,23 @@ from typing import List from setuptools import find_packages, setup -from op_builder.utils import ( - check_cuda_availability, - check_pytorch_version, - check_system_pytorch_cuda_match, - get_cuda_bare_metal_version, - get_pytorch_version, - set_cuda_arch_list, -) - try: - from torch.utils.cpp_extension import CUDA_HOME, BuildExtension + import torch # noqa + from torch.utils.cpp_extension import BuildExtension TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False - CUDA_HOME = None -# Some constants for installation checks -MIN_PYTORCH_VERSION_MAJOR = 1 -MIN_PYTORCH_VERSION_MINOR = 10 THIS_DIR = os.path.dirname(os.path.abspath(__file__)) -BUILD_CUDA_EXT = int(os.environ.get("CUDA_EXT", "0")) == 1 +BUILD_EXT = int(os.environ.get("BUILD_EXT", "0")) == 1 IS_NIGHTLY = int(os.environ.get("NIGHTLY", "0")) == 1 -# a variable to store the op builder -ext_modules = [] - # we do not support windows currently if sys.platform == "win32": raise RuntimeError("Windows is not supported yet. Please try again within the Windows Subsystem for Linux (WSL).") -# check for CUDA extension dependencies -def environment_check_for_cuda_extension_build(): - if not TORCH_AVAILABLE: - raise ModuleNotFoundError( - "[extension] PyTorch is not found while CUDA_EXT=1. You need to install PyTorch first in order to build CUDA extensions" - ) - - if not CUDA_HOME: - raise RuntimeError( - "[extension] CUDA_HOME is not found while CUDA_EXT=1. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build CUDA extensions" - ) - - check_system_pytorch_cuda_match(CUDA_HOME) - check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR) - check_cuda_availability() - - def fetch_requirements(path) -> List[str]: """ This function reads the requirements file. @@ -98,46 +66,35 @@ def get_version() -> str: # write version into version.py with open(version_py_path, "w") as f: f.write(f"__version__ = '{version}'\n") - - # look for pytorch and cuda version - if BUILD_CUDA_EXT: - torch_major, torch_minor, _ = get_pytorch_version() - torch_version = f"{torch_major}.{torch_minor}" - cuda_version = ".".join(get_cuda_bare_metal_version(CUDA_HOME)) - else: - torch_version = None - cuda_version = None - - # write the version into the python file - if torch_version: - f.write(f'torch = "{torch_version}"\n') - else: - f.write("torch = None\n") - - if cuda_version: - f.write(f'cuda = "{cuda_version}"\n') - else: - f.write("cuda = None\n") - return version -if BUILD_CUDA_EXT: - environment_check_for_cuda_extension_build() - set_cuda_arch_list(CUDA_HOME) +if BUILD_EXT: + if not TORCH_AVAILABLE: + raise ModuleNotFoundError( + "[extension] PyTorch is not found while CUDA_EXT=1. You need to install PyTorch first in order to build CUDA extensions" + ) - from op_builder import ALL_OPS + from extensions import ALL_EXTENSIONS op_names = [] + ext_modules = [] - # load all builders - for name, builder_cls in ALL_OPS.items(): - op_names.append(name) - ext_modules.append(builder_cls().builder()) + for ext_cls in ALL_EXTENSIONS: + ext = ext_cls() + if ext.support_aot and ext.is_hardware_available(): + ext.assert_hardware_compatible() + op_names.append(ext.name) + ext_modules.append(ext.build_aot()) # show log - op_name_list = ", ".join(op_names) - print(f"[extension] loaded builders for {op_name_list}") + if len(ext_modules) == 0: + raise RuntimeError("[extension] Could not find any kernel compatible with the current environment.") + else: + op_name_list = ", ".join(op_names) + print(f"[extension] Building extensions{op_name_list}") +else: + ext_modules = [] # always put not nightly branch as the if branch # otherwise github will treat colossalai-nightly as the project name diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py index 5e8e0b382..a16b16ad6 100644 --- a/tests/kit/model_zoo/registry.py +++ b/tests/kit/model_zoo/registry.py @@ -61,7 +61,9 @@ class ModelZooRegistry(dict): """ self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute) - def get_sub_registry(self, keyword: Union[str, List[str]], exclude: Union[str, List[str]] = None): + def get_sub_registry( + self, keyword: Union[str, List[str]], exclude: Union[str, List[str]] = None, allow_empty: bool = False + ): """ Get a sub registry with models that contain the keyword. @@ -95,7 +97,8 @@ class ModelZooRegistry(dict): if not should_exclude: new_dict[k] = v - assert len(new_dict) > 0, f"No model found with keyword {keyword}" + if not allow_empty: + assert len(new_dict) > 0, f"No model found with keyword {keyword}" return new_dict diff --git a/tests/kit/model_zoo/transformers/gptj.py b/tests/kit/model_zoo/transformers/gptj.py index 9eefbb43d..c89124f01 100644 --- a/tests/kit/model_zoo/transformers/gptj.py +++ b/tests/kit/model_zoo/transformers/gptj.py @@ -63,6 +63,9 @@ config = transformers.GPTJConfig( n_layer=2, n_head=4, vocab_size=50258, + n_embd=256, + hidden_size=256, + n_positions=512, attn_pdrop=0, embd_pdrop=0, resid_pdrop=0, diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py index 2c8b260e6..373ba28b8 100644 --- a/tests/test_auto_parallel/test_offload/test_perf.py +++ b/tests/test_auto_parallel/test_offload/test_perf.py @@ -5,13 +5,13 @@ import torch from torch.utils._pytree import tree_map import colossalai +from colossalai.accelerator import get_accelerator from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer from colossalai.auto_parallel.offload.mem_optimize import memory_optimize from colossalai.auto_parallel.offload.solver import NOT_NVML from colossalai.fx.profiler import parameter_size from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper from tests.test_auto_parallel.test_offload.model_utils import * from tests.test_tensor.common_utils import set_seed @@ -31,7 +31,7 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): 64, 8, ), - device=get_current_device(), + device=get_accelerator().get_current_device(), ) criterion = LMLoss() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py index aba746f19..d57717326 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py @@ -10,12 +10,12 @@ try: except: NO_CODEGEN = True +from colossalai.accelerator import get_accelerator from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.nn.optimizer import HybridAdam from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn -from colossalai.utils import get_current_device from colossalai.zero import zero_model_wrapper, zero_optim_wrapper @@ -72,7 +72,11 @@ def check_auto_parallel_with_gemini(rank, world_size, port): print("=" * msg_length) gemini_config = dict( - strict_ddp_mode=False, device=get_current_device(), placement_policy="cpu", pin_memory=True, search_range_m=128 + strict_ddp_mode=False, + device=get_accelerator().get_current_device(), + placement_policy="cpu", + pin_memory=True, + search_range_m=128, ) gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config) diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index e724d7359..67b0bef50 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -1,19 +1,49 @@ +import copy from contextlib import nullcontext from typing import Optional import torch import torch.distributed as dist +from torch.testing import assert_close +from torch.utils.data import Dataset import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin from colossalai.fx import is_compatible_with_meta from colossalai.lazy.lazy_init import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import set_seed from tests.kit.model_zoo import model_zoo +class RandomDataset(Dataset): + def __init__(self, num_samples: int = 100, max_length: int = 512, vocab_size: int = 32000): + self.num_samples = num_samples + self.max_length = max_length + set_seed(42) + self.input_ids = torch.randint( + 0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device() + ) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], + } + + +def move_to_cuda(batch): + return {k: v.cuda() for k, v in batch.items()} + + @clear_cache_before_run() def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: try: @@ -85,10 +115,145 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True): assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) +@parameterize( + "test_args", + [ + { + "batch_size": 8, + "num_steps": 4, + "tp": 2, + "pp": 2, + "pp_style": "1f1b", + "num_model_chunks": 1, + "num_microbatches": 4, + "zero": 0, + "precision": "fp16", + "initial_scale": 1, + "max_length": 512, + "gradient_accumulation_step": 2, + }, + { + "batch_size": 8, + "num_steps": 4, + "tp": 1, + "pp": 2, + "pp_style": "1f1b", + "num_model_chunks": 1, + "num_microbatches": 4, + "zero": 1, + "precision": "fp16", + "initial_scale": 1, + "max_length": 512, + "gradient_accumulation_step": 2, + }, + { + "batch_size": 1, + "num_steps": 4, + "tp": 2, + "pp": 1, + "pp_style": "1f1b", + "num_model_chunks": 1, + "num_microbatches": 1, + "zero": 2, + "precision": "fp16", + "initial_scale": 1, + "max_length": 512, + "gradient_accumulation_step": 2, + }, + { + "batch_size": 1, + "num_steps": 4, + "tp": 2, + "pp": 1, + "pp_style": "1f1b", + "num_model_chunks": 1, + "num_microbatches": 1, + "zero": 0, + "precision": "fp16", + "initial_scale": 1, + "max_length": 512, + "gradient_accumulation_step": 2, + }, + ], +) +def run_grad_acc_test(test_args): + model_fn, *_ = next(iter(model_zoo.get_sub_registry("transformers_gpt_lm").values())) + model = model_fn() + optimizer = HybridAdam(model.parameters()) + origin_model = copy.deepcopy(model).cuda() + origin_optimizer = HybridAdam(origin_model.parameters()) + + plugin = HybridParallelPlugin( + tp_size=test_args["tp"], + pp_size=test_args["pp"], + pp_style=test_args["pp_style"], + zero_stage=test_args["zero"], + num_model_chunks=test_args["num_model_chunks"], + enable_fused_normalization=True, + num_microbatches=test_args["num_microbatches"], + precision=test_args["precision"], + ) + booster = Booster(plugin=plugin) + + dataset = RandomDataset( + num_samples=test_args["batch_size"] * test_args["num_steps"] * plugin.dp_size, + max_length=test_args["max_length"], + vocab_size=model.config.vocab_size, + ) + dataloader = plugin.prepare_dataloader(dataset, batch_size=test_args["batch_size"], shuffle=True, drop_last=True) + + model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) + + grad_accu_step = test_args["gradient_accumulation_step"] + for step, batch in enumerate(dataloader): + batch = move_to_cuda(batch) + # train origin model + origin_output = origin_model(**batch) + origin_loss = origin_output[0] / grad_accu_step + origin_loss.backward() + + if (step + 1) % grad_accu_step != 0 and test_args["zero"] != 2: + ctx = booster.no_sync(model, optimizer) + else: + ctx = nullcontext() + + with ctx: + if plugin.stage_manager is not None: + batch = iter([batch]) + booster.execute_pipeline( + batch, + model, + criterion=lambda outputs, inputs: outputs[0] / grad_accu_step, + optimizer=optimizer, + return_loss=False, + ) + else: + outputs = model(**batch) + loss = outputs[0] / grad_accu_step + booster.backward(loss, optimizer) + + if (step + 1) % grad_accu_step == 0: + # update origin model weight + origin_optimizer.step() + origin_optimizer.zero_grad() + + # update sharded model + optimizer.step() + optimizer.zero_grad() + + # tricky code here, shard the origin model inorder to check the parameters in the same stage. + origin_model, origin_optimizer, _, dataloader, _ = booster.boost( + origin_model, origin_optimizer, dataloader=dataloader + ) + for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()): + assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2) + + def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_3d_plugin(early_stop=early_stop) + run_grad_acc_test() @rerun_if_address_is_in_use() diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 9952e41e5..17dfa3a18 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -12,7 +12,13 @@ from colossalai.fx import is_compatible_with_meta from colossalai.lazy.lazy_init import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.tensor.colo_parameter import ColoParameter -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import ( + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + skip_if_not_enough_gpus, + spawn, +) from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo @@ -172,6 +178,7 @@ def test_gemini_plugin(early_stop: bool = True): @pytest.mark.largedist +@skip_if_not_enough_gpus(8) @rerun_if_address_is_in_use() def test_gemini_plugin_3d(early_stop: bool = True): spawn(run_dist, 8, early_stop=early_stop) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index bcdcc1470..861fa0131 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -5,13 +5,13 @@ import torch.distributed as dist from torch.optim import Adam import colossalai -import colossalai.utils.device as device_utils +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin # from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo # These models are not compatible with AMP _AMP_ERR_MODELS = ["timm_convit", "deepfm_interactionarch"] @@ -21,8 +21,9 @@ _LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"] _STUCK_MODELS = ["transformers_albert_for_multiple_choice"] +@clear_cache_before_run() def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: - device = device_utils.get_current_device() + device = get_accelerator().get_current_device() try: plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5) booster = Booster(plugin=plugin) @@ -74,7 +75,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): continue err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn) - device_utils.empty_cache() + get_accelerator().empty_cache() if err is None: passed_models.append(name) diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py index fa32feb2f..e785843fb 100644 --- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -10,10 +10,11 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import TorchDDPPlugin from colossalai.interface import OptimizerWrapper -from colossalai.testing import rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo +@clear_cache_before_run() def run_fn(model_fn, data_gen_fn, output_transform_fn): plugin = TorchDDPPlugin() booster = Booster(plugin=plugin) diff --git a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py index 8a14d7cf8..f69807046 100644 --- a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py @@ -11,11 +11,12 @@ if version.parse(torch.__version__) >= version.parse("1.12.0"): from colossalai.booster.plugin import TorchFSDPPlugin from colossalai.interface import OptimizerWrapper -from colossalai.testing import rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo # test basic fsdp function +@clear_cache_before_run() def run_fn(model_fn, data_gen_fn, output_transform_fn): plugin = TorchFSDPPlugin() booster = Booster(plugin=plugin) @@ -40,12 +41,18 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): optimizer.clip_grad_by_norm(1.0) optimizer.step() + del model + del optimizer + del criterion + del booster + del plugin + def check_torch_fsdp_plugin(): if IS_FAST_TEST: registry = model_zoo.get_sub_registry(COMMON_MODELS) else: - registry = model_zoo + registry = model_zoo.get_sub_registry("transformers_gptj") for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items(): if any( @@ -59,6 +66,7 @@ def check_torch_fsdp_plugin(): ] ): continue + print(name) run_fn(model_fn, data_gen_fn, output_transform_fn) torch.cuda.empty_cache() @@ -73,3 +81,7 @@ def run_dist(rank, world_size, port): @rerun_if_address_is_in_use() def test_torch_fsdp_plugin(): spawn(run_dist, 2) + + +if __name__ == "__main__": + test_torch_fsdp_plugin() diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 49fd85ffb..708a1906b 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -7,7 +7,6 @@ from transformers import LlamaForCausalLM from utils import shared_tempdir import colossalai -from colossalai.testing import skip_if_not_enough_gpus from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin from colossalai.lazy import LazyInitContext @@ -17,6 +16,7 @@ from colossalai.testing import ( clear_cache_before_run, parameterize, rerun_if_address_is_in_use, + skip_if_not_enough_gpus, spawn, ) from tests.kit.model_zoo import model_zoo @@ -52,7 +52,12 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b bert_model.config.save_pretrained(save_directory=pretrained_path) extra_dp_size = dist.get_world_size() // (zero_size * tp_size) - plugin = GeminiPlugin(**placement_config, tp_size=tp_size, enable_all_optimization=enable_all_optimization, extra_dp_size=extra_dp_size) + plugin = GeminiPlugin( + **placement_config, + tp_size=tp_size, + enable_all_optimization=enable_all_optimization, + extra_dp_size=extra_dp_size, + ) booster = Booster(plugin=plugin) bert_model, _, _, _, _ = booster.boost(bert_model) model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 @@ -78,7 +83,14 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha criterion = lambda x: x.mean() enable_all_optimization = True if tp_size > 1 else False extra_dp_size = dist.get_world_size() // (zero_size * tp_size) - plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization) + plugin = GeminiPlugin( + **placement_config, + precision="fp16", + initial_scale=(2**14), + tp_size=tp_size, + extra_dp_size=extra_dp_size, + enable_all_optimization=enable_all_optimization, + ) booster = Booster(plugin=plugin) model = model_fn() @@ -161,8 +173,13 @@ def run_dist(rank, world_size, port): def test_gemini_ckpIO(): spawn(run_dist, 4) + @pytest.mark.largedist @skip_if_not_enough_gpus(min_gpus=8) @rerun_if_address_is_in_use() def test_gemini_ckpIO_3d(): - spawn(run_dist, 8) \ No newline at end of file + spawn(run_dist, 8) + + +if __name__ == "__main__": + test_gemini_ckpIO() diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index db3c56da8..a42b550cd 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -38,11 +38,11 @@ else: ] -@clear_cache_before_run() @parameterize("shard", [True, False]) @parameterize("model_name", ["transformers_llama_for_casual_lm"]) @parameterize("size_per_shard", [32]) @parameterize("test_config", TEST_CONFIGS) +@clear_cache_before_run() def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict): (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( iter(model_zoo.get_sub_registry(model_name).values()) @@ -104,30 +104,32 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf # Check whether the loaded model & optimizer works smoothly. model.train() new_model.train() + data_for_shard = data_gen_fn() + data_for_origin = data_gen_fn() if booster.plugin.stage_manager is not None: booster.execute_pipeline( - _preprocess_data(data), model, _criterion, optimizer, return_loss=True, return_outputs=False + _preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True, return_outputs=False ) booster.execute_pipeline( - _preprocess_data(data), new_model, _criterion, new_optimizer, return_loss=True, return_outputs=False + _preprocess_data(data_for_origin), + new_model, + _criterion, + new_optimizer, + return_loss=True, + return_outputs=False, ) else: - old_model_loss = criterion(model(**_preprocess_data(data))) + old_model_loss = criterion(model(**_preprocess_data(data_for_shard))) optimizer.backward(old_model_loss) - new_model_loss = criterion(new_model(**_preprocess_data(data))) + new_model_loss = criterion(new_model(**_preprocess_data(data_for_origin))) new_optimizer.backward(new_model_loss) optimizer.step() new_optimizer.step() # Check updated weights. - stage_manager = booster.plugin.stage_manager - - if stage_manager is None or stage_manager.is_first_stage(): - assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3) - assert_close_loose( - model.unwrap().h[0].mlp.c_fc.weight.data, new_model.unwrap().h[0].mlp.c_fc.weight.data, atol=5e-3, rtol=5e-3 - ) + for p1, p2 in zip(model.unwrap().parameters(), new_model.unwrap().parameters()): + assert_close_loose(p1, p2, atol=5e-3, rtol=5e-3) dist.barrier() Randomizer.reset_index() @@ -145,3 +147,7 @@ def run_dist(rank, world_size, port): @rerun_if_address_is_in_use() def test_hybrid_ckpIO(world_size): spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_hybrid_ckpIO(4) diff --git a/tests/test_infer_ops/triton/test_llama_act_combine.py b/tests/test_infer_ops/triton/test_llama_act_combine.py deleted file mode 100644 index 5341aa35a..000000000 --- a/tests/test_infer_ops/triton/test_llama_act_combine.py +++ /dev/null @@ -1,56 +0,0 @@ -import pytest -import torch -from packaging import version -from torch import nn - -from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine - -try: - import triton - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') - -BATCH_SIZE = 4 -SEQ_LEN = 16 -HIDDEN_SIZE = 32 - - -def SwiGLU(x): - """Gated linear unit activation function. - Args: - x : input array - axis: the axis along which the split should be computed (default: -1) - """ - size = x.shape[-1] - assert size % 2 == 0, "axis size must be divisible by 2" - x1, x2 = torch.split(x, size // 2, -1) - return x1 * (x2 * torch.sigmoid(x2.to(torch.float32)).to(x.dtype)) - - -@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton") -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) -def test_llama_act_combine(dtype: str): - x_gate = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE * 2, dtype=dtype).cuda() - x_gate_torch = nn.Parameter(x_gate.detach().clone()) - x_gate_kernel = nn.Parameter(x_gate.detach().clone()) - x_up = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, dtype=dtype).cuda() - x_up_torch = nn.Parameter(x_up.detach().clone()) - x_up_kernel = nn.Parameter(x_up.detach().clone()) - - torch_out = SwiGLU(x_gate_torch) * x_up_torch - kernel_out = LlamaActCombine.apply(x_gate_kernel, x_up_kernel) - atol = 1e-5 if dtype == torch.float32 else 5e-2 - assert torch.allclose(torch_out, kernel_out, atol=atol) - - torch_out.mean().backward() - kernel_out.mean().backward() - assert all(grad is not None for grad in [x_gate_torch.grad, x_up_torch.grad, x_gate_kernel.grad, x_up_kernel.grad]) - assert torch.allclose(x_gate_torch.grad, x_gate_kernel.grad, atol=atol) - assert torch.allclose(x_up_torch.grad, x_up_kernel.grad, atol=atol) - - -if __name__ == '__main__': - test_llama_act_combine(torch.float16) diff --git a/tests/test_infer_ops/triton/test_softmax.py b/tests/test_infer_ops/triton/test_softmax.py deleted file mode 100644 index 43b9c0929..000000000 --- a/tests/test_infer_ops/triton/test_softmax.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest -import torch -from packaging import version -from torch import nn - -try: - from colossalai.kernel.triton.softmax import softmax - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_softmax_op(): - data_samples = [ - torch.randn((3, 4, 5, 32), device="cuda", dtype=torch.float32), - torch.randn((320, 320, 78), device="cuda", dtype=torch.float32), - torch.randn((2345, 4, 5, 64), device="cuda", dtype=torch.float16), - ] - - for data in data_samples: - module = nn.Softmax(dim=-1) - data_torch_out = module(data) - data_triton_out = softmax(data) - check = torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3) - assert check is True, "softmax outputs from triton and torch are not matched" - - -if __name__ == "__main__": - test_softmax_op() diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py index ee50e5b61..d0c4cd0a7 100644 --- a/tests/test_lazy/test_models.py +++ b/tests/test_lazy/test_models.py @@ -1,14 +1,19 @@ import pytest from lazy_init_utils import SUPPORT_LAZY, check_lazy_init -from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS +from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo @pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0") -@pytest.mark.parametrize("subset", [COMMON_MODELS] if IS_FAST_TEST else ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"]) +@pytest.mark.parametrize( + "subset", + [COMMON_MODELS] + if IS_FAST_TEST + else ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"], +) @pytest.mark.parametrize("default_device", ["cpu", "cuda"]) def test_torchvision_models_lazy_init(subset, default_device): - sub_model_zoo = model_zoo.get_sub_registry(subset) + sub_model_zoo = model_zoo.get_sub_registry(subset, allow_empty=True) for name, entry in sub_model_zoo.items(): # TODO(ver217): lazy init does not support weight norm, skip these models if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith( diff --git a/tests/test_legacy/test_comm/test_comm.py b/tests/test_legacy/test_comm/test_comm.py index 7d2c81972..079022e93 100644 --- a/tests/test_legacy/test_comm/test_comm.py +++ b/tests/test_legacy/test_comm/test_comm.py @@ -2,12 +2,12 @@ import pytest import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import all_gather, all_reduce, reduce_scatter from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.initialize import launch from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1))) @@ -16,7 +16,7 @@ SIZE = 8 def check_all_gather(): tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) - tensor = tensor.to(get_current_device()) + tensor = tensor.to(get_accelerator().get_current_device()) print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) tensor, op = all_gather(tensor, 0, ParallelMode.GLOBAL, async_op=True) print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) @@ -27,7 +27,7 @@ def check_all_gather(): def check_reduce_scatter(): tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) - tensor = tensor.to(get_current_device()) + tensor = tensor.to(get_accelerator().get_current_device()) print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) tensor, op = reduce_scatter(tensor, 0, ParallelMode.GLOBAL, async_op=True) print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) @@ -38,7 +38,7 @@ def check_reduce_scatter(): def check_all_reduce(): tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) - tensor = tensor.to(get_current_device()) + tensor = tensor.to(get_accelerator().get_current_device()) print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) tensor, op = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True) print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) diff --git a/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py index 8a9a73d65..f09df9253 100644 --- a/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py +++ b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py @@ -2,6 +2,7 @@ import torch import torch.distributed as dist from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.global_variables import tensor_parallel_env as env @@ -16,13 +17,12 @@ from colossalai.legacy.nn import ( VocabParallelEmbedding1D, ) from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal def check_linear_col(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE @@ -68,7 +68,7 @@ def check_linear_col(): print_rank_0("linear_col forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) dist.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] grad = grad.clone() @@ -91,7 +91,7 @@ def check_linear_col(): def check_linear_row(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE @@ -137,7 +137,7 @@ def check_linear_row(): print_rank_0("linear_row forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) dist.broadcast(grad_master, src=0) grad = grad_master.clone() out.backward(grad) @@ -159,7 +159,7 @@ def check_linear_row(): def check_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -201,7 +201,7 @@ def check_embed(): def check_vocab_parallel_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -243,7 +243,7 @@ def check_vocab_parallel_embed(): def check_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -309,7 +309,7 @@ def check_classifier_no_given_weight(): def check_vocab_parallel_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -369,7 +369,7 @@ def check_vocab_parallel_classifier_no_given_weight(): def check_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -420,7 +420,7 @@ def check_classifier_given_embed_weight(): def check_vocab_parallel_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -472,7 +472,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): def check_vocab_parallel_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -508,7 +508,7 @@ def check_vocab_parallel_loss(): @torch.no_grad() def check_linear_row_stream_inference(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py index 0bbc72eca..78bd407b9 100644 --- a/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py +++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py @@ -1,5 +1,6 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn import ( @@ -16,13 +17,12 @@ from colossalai.legacy.nn import ( VocabParallelEmbedding2D, ) from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal def check_linear(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = HIDDEN_SIZE @@ -74,7 +74,7 @@ def check_linear(): print_rank_0("linear forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -103,7 +103,7 @@ def check_linear(): def check_layernorm(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE EPS = 1e-12 @@ -139,7 +139,7 @@ def check_layernorm(): print_rank_0("layer norm forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -154,7 +154,7 @@ def check_layernorm(): def check_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) @@ -201,7 +201,7 @@ def check_embed(): def check_patch_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) @@ -274,7 +274,7 @@ def check_patch_embed(): def check_vocab_parallel_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) @@ -321,7 +321,7 @@ def check_vocab_parallel_embed(): def check_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = NUM_CLASSES @@ -371,7 +371,7 @@ def check_classifier_no_given_weight(): print_rank_0("classifier (no given weight) forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] # grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -399,7 +399,7 @@ def check_classifier_no_given_weight(): def check_vocab_parallel_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -467,7 +467,7 @@ def check_vocab_parallel_classifier_no_given_weight(): def check_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -519,7 +519,7 @@ def check_classifier_given_embed_weight(): def check_vocab_parallel_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -573,7 +573,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): def check_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -608,7 +608,7 @@ def check_loss(): def check_vocab_parallel_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -645,7 +645,7 @@ def check_vocab_parallel_loss(): # def check_attention(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE # NUM_ATTENTION_HEADS = 2 @@ -683,7 +683,7 @@ def check_vocab_parallel_loss(): # print_rank_0('self attention backward: pass') # def check_mlp(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE @@ -716,7 +716,7 @@ def check_vocab_parallel_loss(): # print_rank_0('mlp backward: pass') # def check_transformerlayer(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE # NUM_ATTENTION_HEADS = 2 diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py index 9c126cefe..4506cfee6 100644 --- a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py +++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py @@ -3,11 +3,11 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, SEQ_LENGTH, check_equal @@ -27,7 +27,7 @@ def check_AB(): i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device()) + A_master = torch.randn(A_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(A_master, src=0) A = torch.chunk(A_master, DEPTH, dim=0)[i] A = torch.chunk(A, DEPTH, dim=-1)[j] @@ -35,7 +35,7 @@ def check_AB(): A.requires_grad = True B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) - B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device()) + B_master = torch.randn(B_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(B_master, src=0) B = torch.chunk(B_master, DEPTH, dim=0)[i] B = torch.chunk(B, DEPTH, dim=-1)[j] @@ -72,7 +72,7 @@ def check_AB(): print_rank_0("AB forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -105,7 +105,7 @@ def check_ABT(): tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) dtype = torch.float - device = get_current_device() + device = get_accelerator().get_current_device() j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) @@ -184,7 +184,7 @@ def check_ATB(): ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py index 283e7f683..914607614 100644 --- a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py @@ -1,6 +1,7 @@ import torch from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn import ( @@ -17,13 +18,12 @@ from colossalai.legacy.nn import ( VocabParallelEmbedding2p5D, ) from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import * def check_linear(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE @@ -76,7 +76,7 @@ def check_linear(): print_rank_0("linear forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] @@ -104,7 +104,7 @@ def check_linear(): def check_layernorm(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE EPS = 1e-12 @@ -141,7 +141,7 @@ def check_layernorm(): print_rank_0("layer norm forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] @@ -156,7 +156,7 @@ def check_layernorm(): def check_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -204,7 +204,7 @@ def check_embed(): def check_patch_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -278,7 +278,7 @@ def check_patch_embed(): def check_vocab_parallel_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -326,7 +326,7 @@ def check_vocab_parallel_embed(): def check_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = NUM_CLASSES @@ -377,7 +377,7 @@ def check_classifier_no_given_weight(): print_rank_0("classifier (no given weight) forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] # grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] @@ -405,7 +405,7 @@ def check_classifier_no_given_weight(): def check_vocab_parallel_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -472,7 +472,7 @@ def check_vocab_parallel_classifier_no_given_weight(): def check_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -524,7 +524,7 @@ def check_classifier_given_embed_weight(): def check_vocab_parallel_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -578,7 +578,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): def check_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -613,7 +613,7 @@ def check_loss(): def check_vocab_parallel_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -650,7 +650,7 @@ def check_vocab_parallel_loss(): # def check_attention(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE # NUM_ATTENTION_HEADS = 2 @@ -689,7 +689,7 @@ def check_vocab_parallel_loss(): # print_rank_0('self attention backward: pass') # def check_mlp(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE @@ -725,7 +725,7 @@ def check_vocab_parallel_loss(): # print_rank_0('mlp backward: pass') # def check_transformerlayer(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE # NUM_ATTENTION_HEADS = 2 diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py index 992bd6107..91a15c81d 100644 --- a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py @@ -1,10 +1,10 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import * @@ -25,7 +25,7 @@ def check_AB(): k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device()) + A_master = torch.randn(A_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(A_master, src=0) A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] @@ -33,7 +33,7 @@ def check_AB(): A.requires_grad = True B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) - B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device()) + B_master = torch.randn(B_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(B_master, src=0) B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] @@ -70,7 +70,7 @@ def check_AB(): print_rank_0("AB forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] @@ -103,7 +103,7 @@ def check_ABT(): tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) dtype = torch.float - device = get_current_device() + device = get_accelerator().get_current_device() i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -184,7 +184,7 @@ def check_ATB(): ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) diff --git a/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py index a4a4ae9a5..f9f19a17b 100644 --- a/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py +++ b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py @@ -5,6 +5,7 @@ import time import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.legacy.core import global_context from colossalai.legacy.nn import ( @@ -23,7 +24,6 @@ from colossalai.legacy.nn import ( from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from colossalai.legacy.utils import print_rank_0 from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal @@ -31,7 +31,7 @@ from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_L def check_linear(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE @@ -84,7 +84,7 @@ def check_linear(): logger.info("Rank {} linear forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, device=get_current_device()) + grad_master = torch.randn(grad_shape, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -119,7 +119,7 @@ def check_linear(): def check_layernorm(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() INPUT_SIZE = HIDDEN_SIZE input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -206,7 +206,7 @@ def check_layernorm(): def check_classifier_no_given_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() INPUT_SIZE = HIDDEN_SIZE input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -258,7 +258,7 @@ def check_classifier_no_given_weight(): logger.info("Rank {} classifier (no given weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, device=get_current_device()) + grad_master = torch.randn(grad_shape, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=0)[j] @@ -306,7 +306,7 @@ def check_classifier_no_given_weight(): def check_vocab_parallel_classifier_no_given_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() INPUT_SIZE = HIDDEN_SIZE input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -413,7 +413,7 @@ def check_vocab_parallel_classifier_no_given_weight(): def check_classifier_given_embed_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -463,7 +463,7 @@ def check_classifier_given_embed_weight(): logger.info("Rank {} classifier (given embed weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=0)[j] @@ -497,7 +497,7 @@ def check_classifier_given_embed_weight(): def check_vocab_parallel_classifier_given_embed_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -580,7 +580,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): def check_patch_embed(): rank = torch.distributed.get_rank() - device = get_current_device() + device = get_accelerator().get_current_device() logger = get_dist_logger() torch.float32 @@ -678,7 +678,7 @@ def check_patch_embed(): def check_embed(): rank = torch.distributed.get_rank() - device = get_current_device() + device = get_accelerator().get_current_device() logger = get_dist_logger() torch.float32 @@ -746,7 +746,7 @@ def check_embed(): def check_vocab_parallel_embed(): rank = torch.distributed.get_rank() - device = get_current_device() + device = get_accelerator().get_current_device() logger = get_dist_logger() torch.float32 @@ -823,7 +823,7 @@ def check_vocab_parallel_embed(): def check_loss(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -876,7 +876,7 @@ def check_loss(): def check_vocab_parallel_loss(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) diff --git a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py index aa4d5d6ce..f4ad0d6d1 100644 --- a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py +++ b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py @@ -1,9 +1,9 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn import TransformerSelfAttentionRing -from colossalai.utils import get_current_device def check_selfattention(): @@ -13,10 +13,10 @@ def check_selfattention(): HIDDEN_SIZE = 16 layer = TransformerSelfAttentionRing(16, 8, 8, 0.1) - layer = layer.to(get_current_device()) + layer = layer.to(get_accelerator().get_current_device()) - hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device()) + hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_accelerator().get_current_device()) attention_mask = torch.randint(low=0, high=2, size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to( - get_current_device() + get_accelerator().get_current_device() ) layer(hidden_states, attention_mask) diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py index a5a2d3857..cab111358 100644 --- a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py +++ b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py @@ -5,6 +5,7 @@ import pytest import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import ( recv_backward, recv_forward, @@ -18,7 +19,6 @@ from colossalai.legacy.core import global_context as gpc from colossalai.legacy.initialize import launch from colossalai.logging import get_dist_logger from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device BATCH_SIZE = 4 SEQ_LENGTH = 2 @@ -73,7 +73,7 @@ def check_forward_backward(output_tensor, output_grad, rank, logger): def check_comm(size, rank, prev_rank, next_rank, logger): dtype = torch.float32 - device = get_current_device() + device = get_accelerator().get_current_device() tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) tensor = torch.randn(tensor_shape, dtype=dtype, device=device) diff --git a/tests/test_legacy/test_utils/test_memory.py b/tests/test_legacy/test_utils/test_memory.py index 9df7cf75a..4993df4f3 100644 --- a/tests/test_legacy/test_utils/test_memory.py +++ b/tests/test_legacy/test_utils/test_memory.py @@ -1,15 +1,15 @@ import pytest import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction from colossalai.testing import spawn -from colossalai.utils.device import get_current_device def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): - frac1 = colo_device_memory_capacity(get_current_device()) + frac1 = colo_device_memory_capacity(get_accelerator().get_current_device()) colo_set_process_memory_fraction(0.5) - frac2 = colo_device_memory_capacity(get_current_device()) + frac2 = colo_device_memory_capacity(get_accelerator().get_current_device()) assert frac2 * 2 == frac1 diff --git a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py index b5f2be705..9975cc04f 100644 --- a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py +++ b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py @@ -4,12 +4,12 @@ from torch.nn.parameter import Parameter from torch.nn.utils import clip_grad_norm_ import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.tensor import ColoTensorSpec, ProcessGroup, distspec from colossalai.legacy.utils.common import clip_grad_norm from colossalai.logging import disable_existing_loggers from colossalai.tensor.colo_parameter import ColoParameter from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8): @@ -36,7 +36,7 @@ def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None: @parameterize("norm_type", [2.0, 3.0, float("inf")]) def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float): print(f"{world_size}, {dtype}, {device}, {norm_type}") - cuda_device = get_current_device() + cuda_device = get_accelerator().get_current_device() devices = [cuda_device] * 4 if device == "cpu": devices = [torch.device("cpu")] * 4 diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index 3fac62472..a349bc5a9 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -4,10 +4,10 @@ import torch.distributed as dist import torch.nn as nn import colossalai +from colossalai.accelerator import get_accelerator from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from tests.test_moe.moe_utils import MoeGradientHandler BATCH_SIZE = 4 @@ -38,7 +38,7 @@ def run_test(rank, world_size, port): layer_list.append(moe_layer) model = nn.ModuleList(layer_list) - model = model.to(get_current_device()) + model = model.to(get_accelerator().get_current_device()) dist_dict = MOE_MANAGER.parallel_info_dict assert_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group) assert_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group) @@ -52,7 +52,7 @@ def run_test(rank, world_size, port): rank = dist.get_rank() torch.cuda.manual_seed(78 + rank) - data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) + data = torch.randn(BATCH_SIZE, DIM, device=get_accelerator().get_current_device()) grad = torch.randn_like(data) MOE_MANAGER.reset_loss() diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 255ec7444..62d61a3d4 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -3,10 +3,10 @@ import torch import torch.distributed as dist import colossalai +from colossalai.accelerator import get_accelerator from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device BATCH_SIZE = 4 NUM_EXPERTS = 4 @@ -28,7 +28,9 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f torch.manual_seed(rs + local_rank) # set each process has different random seed # get randomized data - tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True) + tokens = torch.randn( + BATCH_SIZE, hidden_size, dtype=data_type, device=get_accelerator().get_current_device(), requires_grad=True + ) layer = SparseMLP( hidden_size=hidden_size, @@ -37,7 +39,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f router_top_k=topk, router_capacity_factor_train=1.0, ) - layer = layer.to(get_current_device()) + layer = layer.to(get_accelerator().get_current_device()) if data_type == torch.float16: layer = layer.half() @@ -45,7 +47,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f layer.enable_kernel = False old_out = layer(tokens) ech = old_out.shape - grad = torch.randn(ech, device=get_current_device()) + grad = torch.randn(ech, device=get_accelerator().get_current_device()) old_out.backward(grad) # get gradient # save all results diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index bd1103df3..8f51e1663 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -9,11 +9,11 @@ import torch.distributed as dist from transformers.models.llama import LlamaConfig import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device sys.path.append( os.path.join( @@ -28,7 +28,7 @@ OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenM def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20): - input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device()) + input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_accelerator().get_current_device()) attention_mask = torch.ones_like(input_ids) return { "input_ids": input_ids, diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index f87d4c792..74feeeb59 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -7,12 +7,12 @@ import torch import torch.distributed as dist import colossalai +from colossalai.accelerator import get_accelerator from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import sync_moe_model_param from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from tests.test_moe.moe_utils import MoeGradientHandler @@ -23,8 +23,9 @@ def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_ tp_model (MoeModule) local_model (MoeModule) """ - for (tp_name, tp_param), (local_name, local_param) in \ - zip(tp_model.named_parameters(), local_model.named_parameters()): + for (tp_name, tp_param), (local_name, local_param) in zip( + tp_model.named_parameters(), local_model.named_parameters() + ): assert tp_name == local_name if not is_moe_tensor(tp_param): if assert_grad_flag: @@ -54,8 +55,7 @@ def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: tp_model (MoeModule) ep_model (MoeModule) """ - for (tp_name, tp_param), (ep_name, ep_param) in \ - zip(tp_model.named_parameters(), ep_model.named_parameters()): + for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()): assert tp_name == ep_name if not is_moe_tensor(tp_param): if assert_grad_flag: @@ -97,8 +97,9 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_ local_model (MoeModule) ep_model (MoeModule) """ - for (local_name, local_param), (ep_name, ep_param) in \ - zip(local_model.named_parameters(), ep_model.named_parameters()): + for (local_name, local_param), (ep_name, ep_param) in zip( + local_model.named_parameters(), ep_model.named_parameters() + ): assert local_name == ep_name if "experts" not in local_name: if assert_grad_flag: @@ -141,14 +142,14 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2, - enable_hierarchical_comm=enable_hierarchical_comm + enable_hierarchical_comm=enable_hierarchical_comm, ) MOE_MANAGER.__init__() MOE_MANAGER.setup(parallel="TP") tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) - ep_model = ep_model.to(get_current_device()) - tp_model = tp_model.to(get_current_device()) - local_model = local_model.to(get_current_device()) + ep_model = ep_model.to(get_accelerator().get_current_device()) + tp_model = tp_model.to(get_accelerator().get_current_device()) + local_model = local_model.to(get_accelerator().get_current_device()) # sync ep param sync_moe_model_param(ep_model) @@ -163,11 +164,11 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size tp_grad_handler = MoeGradientHandler(tp_model) rank = dist.get_rank() - input_data = torch.randn(batch_size, dim, device=get_current_device()) + input_data = torch.randn(batch_size, dim, device=get_accelerator().get_current_device()) micro_batch_size = batch_size // world_size index = rank * micro_batch_size # NOTE: ep & tp takes in sharded data for each process - shard_data = input_data.detach()[index:index + micro_batch_size] + shard_data = input_data.detach()[index : index + micro_batch_size] out_local = local_model(input_data) MOE_MANAGER.reset_loss() @@ -176,13 +177,15 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size out_ep = ep_model(shard_data) MOE_MANAGER.reset_loss() - assert torch.allclose(out_tp, out_ep, atol=1e-6), \ - f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}" + assert torch.allclose( + out_tp, out_ep, atol=1e-6 + ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}" try: - out_local_slice = out_local[index:index + micro_batch_size] - assert torch.allclose(out_ep, out_local_slice, atol=1e-6), \ - f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}" - except AssertionError as e: + out_local_slice = out_local[index : index + micro_batch_size] + assert torch.allclose( + out_ep, out_local_slice, atol=1e-6 + ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}" + except AssertionError: """ e.g., in local model, tokens = 4, capacity = 2, experts = 2, topk = 1 router yields [01] --> [0], [23] --> [1], this is valid as capacity is 2 @@ -193,8 +196,7 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size The same thing happens on router 1. And finally some tokens are dropped due to the sharded nature. """ warnings.warn( - "EP & TP may result in different behavior from local model. " - "Please check the comments for details." + "EP & TP may result in different behavior from local model. " "Please check the comments for details." ) out_local.mean().backward() @@ -208,10 +210,9 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True) try: sync_local_from_ep(local_model, ep_model, assert_grad_flag=True) - except AssertionError as e: + except AssertionError: warnings.warn( - "EP & TP may result in different behavior from local model. " - "Please check the comments for details." + "EP & TP may result in different behavior from local model. " "Please check the comments for details." ) @@ -219,14 +220,17 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size @pytest.mark.parametrize("num_experts", [4, 64]) @pytest.mark.parametrize("batch_size", [16]) @pytest.mark.parametrize("dim", [64]) -@pytest.mark.parametrize("config", [ - {"enable_hierarchical_comm": False}, - {"enable_hierarchical_comm": True}, -]) +@pytest.mark.parametrize( + "config", + [ + {"enable_hierarchical_comm": False}, + {"enable_hierarchical_comm": True}, + ], +) @rerun_if_address_is_in_use() def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict): spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_ep_tp(num_experts=8, batch_size=32, dim=32) diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index 95c0e715d..2f08a335d 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -3,11 +3,11 @@ import torch.distributed as dist import torch.nn as nn import colossalai +from colossalai.accelerator import get_accelerator from colossalai.moe.experts import MLPExperts from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import sync_moe_model_param from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device HIDDEN_SIZE = 4 INTERMEDIATE_SIZE = 8 @@ -46,7 +46,7 @@ def run_moe_init(expert_parallel): assert dist.get_rank(parallel_info_dict[1].dp_group) == rank model = nn.ModuleList([exp0, exp1, exp2]) - model = model.to(get_current_device()) + model = model.to(get_accelerator().get_current_device()) sync_moe_model_param(model) # MOE experts layout success when ep_size = 1 diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py index 6bbe3e4e8..6d932156a 100644 --- a/tests/test_optimizer/test_adam_kernel.py +++ b/tests/test_optimizer/test_adam_kernel.py @@ -8,7 +8,8 @@ import pytest import torch from torch import Tensor -from colossalai.utils import get_current_device, multi_tensor_applier +from colossalai.accelerator import get_accelerator +from colossalai.utils import multi_tensor_applier _FUSED_ALLOWED_P_G_TYPES = [ (torch.float, torch.half), @@ -64,9 +65,9 @@ class TorchAdamKernel(AdamKernel): class FusedAdamKernel(AdamKernel): def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() self.fused_adam = fused_optim.multi_tensor_adam self.dummy_overflow_buf = torch.cuda.IntTensor([0]) @@ -90,9 +91,9 @@ class FusedAdamKernel(AdamKernel): class CPUAdamKernel(AdamKernel): def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) - from colossalai.kernel.op_builder import CPUAdamBuilder + from colossalai.kernel.kernel_loader import CPUAdamLoader - cpu_optim = CPUAdamBuilder().load() + cpu_optim = CPUAdamLoader().load() self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw) @@ -155,7 +156,9 @@ def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype): rtol, atol = 1e-3, 1e-3 if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16: rtol, atol = 4e-3, 4e-3 - check_adam_kernel(FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_current_device(), 3, rtol, atol) + check_adam_kernel( + FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_accelerator().get_current_device(), 3, rtol, atol + ) @pytest.mark.parametrize("adamw", [False, True]) diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py index caf6e6bbb..6f5e734b7 100644 --- a/tests/test_pipeline/test_p2p_communication.py +++ b/tests/test_pipeline/test_p2p_communication.py @@ -3,11 +3,11 @@ import torch import torch.distributed as dist import colossalai +from colossalai.accelerator import get_accelerator from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device WORLD_SIZE = 2 @@ -19,7 +19,7 @@ def check_p2p_communication(): rank = dist.get_rank() - tensor = torch.ones(1, device=get_current_device()) + tensor = torch.ones(1, device=get_accelerator().get_current_device()) data = [ "tensor", tensor, diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py index 5f27be396..a08dc6d27 100644 --- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -155,7 +155,7 @@ def run_dist( @pytest.mark.dist -@pytest.mark.parametrize("num_microbatch", [4, 12]) +@pytest.mark.parametrize("num_microbatch", [4, 6]) @pytest.mark.parametrize("batch_size", [12]) @pytest.mark.parametrize("world_size", [2, 4]) @rerun_if_address_is_in_use() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 87e661802..62d4d1bf3 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -154,7 +154,7 @@ def run_forward_backward_with_hybrid_plugin( data = data_gen_fn() - if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0: + if booster.plugin.shard_config.enable_sequence_parallelism and booster.plugin.tp_size != 0: seq_len = data["input_ids"].shape[-1] lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) times = lcm // seq_len diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index a5c465ba0..3ec170004 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -4,13 +4,11 @@ import pytest import torch from einops import rearrange -from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN -from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN +from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN from colossalai.testing import clear_cache_before_run, parameterize if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN: - from colossalai.kernel.cuda_native import ColoAttention - from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention DTYPE = [torch.float16, torch.bfloat16, torch.float32] diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py index 5977c706f..e4dc569b8 100644 --- a/tests/test_zero/test_gemini/test_chunkv2.py +++ b/tests/test_zero/test_gemini/test_chunkv2.py @@ -4,15 +4,15 @@ import torch.distributed as dist from torch.distributed.distributed_c10d import _get_default_group import colossalai +from colossalai.accelerator import get_accelerator from colossalai.tensor import ColoParameter from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from colossalai.zero.gemini import TensorState from colossalai.zero.gemini.chunk import Chunk def dist_sum(x): - temp = torch.tensor([x], device=get_current_device()) + temp = torch.tensor([x], device=get_accelerator().get_current_device()) dist.all_reduce(temp) return temp.item() @@ -66,7 +66,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory): assert my_chunk.cpu_shard.size(0) == 1024 // world_size assert my_chunk.device_type == "cpu" assert my_chunk.can_move - my_chunk.shard_move(get_current_device()) + my_chunk.shard_move(get_accelerator().get_current_device()) else: assert my_chunk.cuda_global_chunk.size(0) == 1024 assert my_chunk.device_type == "cuda" diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index 21afff753..3a9742e01 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -5,11 +5,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd_bwd @@ -47,7 +47,7 @@ def exam_gpt_fwd_bwd( use_grad_checkpoint: bool = False, master_weights: bool = True, ): - init_device = get_current_device() + init_device = get_accelerator().get_current_device() model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( iter(model_zoo.get_sub_registry(model_name).values()) ) diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py index 35323e516..36a803492 100644 --- a/tests/test_zero/test_gemini/test_grad_accum.py +++ b/tests/test_zero/test_gemini/test_grad_accum.py @@ -6,10 +6,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.nn.optimizer import HybridAdam from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd @@ -53,7 +53,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): def exam_gemini_grad_acc( placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool ): - init_device = get_current_device() + init_device = get_accelerator().get_current_device() model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( iter(model_zoo.get_sub_registry(model_name).values()) ) diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py index 152bf2895..7f3c7176e 100644 --- a/tests/test_zero/test_gemini/test_inference.py +++ b/tests/test_zero/test_gemini/test_inference.py @@ -7,11 +7,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd @@ -47,7 +47,9 @@ def multi_chunk_init(model: torch.nn.Module, placement_config: dict): def single_chunk_init(model: torch.nn.Module, placement_config: dict): - model = GeminiDDP(model, chunk_init_device=get_current_device(), pin_memory=True, **placement_config) + model = GeminiDDP( + model, chunk_init_device=get_accelerator().get_current_device(), pin_memory=True, **placement_config + ) return model @@ -63,7 +65,7 @@ def exam_inference(placement_config: dict, model_name: str, model_init_func: Cal torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) - init_dev = get_current_device() + init_dev = get_accelerator().get_current_device() model = model_builder().to(init_dev) for torch_p, p in zip(torch_model.parameters(), model.parameters()): diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 405d7d789..71bb27b4a 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -5,11 +5,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd_bwd @@ -150,7 +150,7 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch. model = GeminiDDP( model, - chunk_init_device=get_current_device(), + chunk_init_device=get_accelerator().get_current_device(), search_range_m=1, pin_memory=True, mixed_precision=mixed_precision, diff --git a/tests/test_zero/test_gemini/test_search.py b/tests/test_zero/test_gemini/test_search.py index e99f6d59b..cf3658bf9 100644 --- a/tests/test_zero/test_gemini/test_search.py +++ b/tests/test_zero/test_gemini/test_search.py @@ -2,8 +2,8 @@ import pytest import torch import colossalai +from colossalai.accelerator import get_accelerator from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration from tests.kit.model_zoo import model_zoo @@ -34,7 +34,7 @@ def exam_chunk_manager(): sharded_ddp_model = model_builder() chunk_manager = init_chunk_manager( sharded_ddp_model, - get_current_device(), + get_accelerator().get_current_device(), hidden_dim=128, search_range_m=1, min_chunk_size_m=0, diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py index 351ae5f67..11f738615 100644 --- a/tests/test_zero/test_low_level/test_grad_acc.py +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -7,9 +7,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.testing import spawn from colossalai.testing.random import seed_all -from colossalai.utils import conditional_context, get_current_device +from colossalai.utils import conditional_context from colossalai.zero import LowLevelZeroOptimizer @@ -28,7 +29,7 @@ class MlpModel(nn.Module): def exam_zero_1_2_grad_acc(): local_rank = torch.distributed.get_rank() seed_all(2009) - device = get_current_device() + device = get_accelerator().get_current_device() # create model zero1_model = MlpModel().to(device) zero2_model = copy.deepcopy(zero1_model) @@ -71,7 +72,7 @@ def exam_zero_1_2_grad_acc(): def exam_zero_1_grad_acc(sync): local_rank = torch.distributed.get_rank() seed_all(2008) - device = get_current_device() + device = get_accelerator().get_current_device() # create models zero_model = MlpModel()