[Sync] Update from main to feature/colossal-infer (Merge pull request #5685)

[Sync] Update from main to feature/colossal-infer

- Merge pull request #5685 from yuanheng-zhao/inference/merge/main
pull/5695/head
Yuanheng Zhao 2024-05-06 14:43:38 +08:00 committed by GitHub
commit db7b3051f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
378 changed files with 4671 additions and 2425 deletions

View File

@ -56,7 +56,7 @@ jobs:
needs: detect-changed-doc
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.0.0-11.7.0
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm
timeout-minutes: 20
defaults:

View File

@ -24,7 +24,7 @@ jobs:
version=$(cat version.txt)
tag=hpcaitech/colossalai:$version
latest=hpcaitech/colossalai:latest
docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 --build-arg VERSION=v${version} -t $tag ./docker
docker build --build-arg VERSION=v${version} -t $tag ./docker
docker tag $tag $latest
echo "tag=${tag}" >> $GITHUB_OUTPUT
echo "latest=${latest}" >> $GITHUB_OUTPUT

15
LICENSE
View File

@ -552,3 +552,18 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
---------------- LICENSE FOR Hugging Face accelerate ----------------
Copyright 2021 The HuggingFace Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -25,6 +25,8 @@
</div>
## Latest News
* [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
* [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series)
* [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here)
* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
* [2024/03] [Open-SoraSora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
@ -52,7 +54,7 @@
<li>
<a href="#Parallel-Training-Demo">Parallel Training Demo</a>
<ul>
<li><a href="#LLaMA2">LLaMA 1/2</a></li>
<li><a href="#LLaMA3">LLaMA 1/2/3 </a></li>
<li><a href="#MoE">MoE</a></li>
<li><a href="#GPT-3">GPT-3</a></li>
<li><a href="#GPT-2">GPT-2</a></li>
@ -131,7 +133,7 @@ distributed training and inference in a few lines.
[Open-Sora](https://github.com/hpcaitech/Open-Sora)Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models
[[code]](https://github.com/hpcaitech/Open-Sora)
[[blog]](https://hpc-ai.com/blog/open-sora-v1.0)
[[blog]](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
[[HuggingFace model weights]](https://huggingface.co/hpcai-tech/Open-Sora)
[[Demo]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo)
@ -270,13 +272,21 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
<p align="right">(<a href="#top">back to top</a>)</p>
## Parallel Training Demo
### LLaMA3
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/images/LLaMA3-70B-H100.png" width=600/>
</p>
- 70 billion parameter LLaMA3 model training accelerated by 18%
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
### LLaMA2
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/llama2_pretraining.png" width=600/>
</p>
- 70 billion parameter LLaMA2 model training accelerated by 195%
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2)
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)
### LLaMA1
@ -285,7 +295,7 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
</p>
- 65-billion-parameter large model pretraining accelerated by 38%
[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)
### MoE

View File

@ -1 +0,0 @@
0.0.1

View File

@ -1,6 +1,6 @@
<div align="center">
<h1>
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/colossalllam2.jpg?raw=true" width=800/>
Colossal-LLaMA
</h1>
</div>
@ -47,6 +47,7 @@
- [Citations](#citations)
## News
* [2024/4] Support continual pre-training and supervised fine-tuning of LLaMA-3.
* [2024/01] [Construct Refined 13B Private Model With Just $5000 USD, Upgraded Colossal-AI Llama-2 Open Source](https://hpc-ai.com/blog/colossal-llama-2-13b).
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2)
[[blog]](https://hpc-ai.com/blog/colossal-llama-2-13b)
@ -289,7 +290,7 @@ Here is details about CLI arguments:
#### 1. Install required packages
```
cd Colossal-LLaMA-2
cd Colossal-LLaMA
pip install -r requirements.txt
```
#### 2. Install `xentropy`, `layer_norm` and `rotary`
@ -314,7 +315,7 @@ Initialize new tokenizer with additional Chinese tokens. Additional Chinese toke
Command to initialize new tokenizer:
```bash
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION='python'
python colossal_llama2/tokenizer/init_tokenizer.py \
python colossal_llama/tokenizer/init_tokenizer.py \
--source_tokenizer_dir "<SOURCE_TOKENIZER_DIR>" \
--target_tokenizer_dir "<TARGET_TOKENIZER_DIR>" \
--expand_tokens_file "<NEW_TOKENS_FILE>.jsonl"
@ -328,7 +329,7 @@ Here is details about CLI arguments:
Initialize the new model checkpoint by calculating the mean values from the original model checkpoint.
Command to initialize new model checkpoint:
```bash
python colossal_llama2/model/init_model.py \
python colossal_llama/model/init_model.py \
--source_model_and_tokenizer_path "<SOURCE_MODEL_AND_TOKENIZER_DIR>" \
--target_tokenizer_path "<TARGET_TOKENIZER_DIR>" \
--target_model_path "<TARGET_MODEL_DIR>"
@ -362,18 +363,17 @@ Command to convert jsonl dataset to arrow format:
python prepare_pretrain_dataset.py \
--data_input_dirs "<JSONL_DIR_1>,<JSONL_DIR_2>,<JSONL_DIR_3>" \
--tokenizer_dir "<TOKENIZER_DIR>" \
--data_cache_dir "jsonl_to_arrow_cache" \
--data_jsonl_output_dir "spliced_tokenized_output_jsonl" \
--data_arrow_output_dir "spliced_tokenized_output_arrow" \
--data_output_dirs "spliced tokenized output" \
--max_length 4096 \
--num_spliced_dataset_bins 10
```
Here is details about CLI arguments:
* Source data directory: `data_input_dirs`. Each `<JSONL_DIR>` can have multiple file in `jsonl` format.
* Tokenizer directory: `tokenizer_dir`. Path to the tokenizer in Hugging Face format.
* Data cache directory: `data_cache_dir`. Directory to store Hugging Face data cache. Default case will create `cache` folder locally.
* Output directory for jsonl format: `data_jsonl_output_dir`. Output directory to store converted dataset in jsonl format.
* Output directory for arrow format: `data_arrow_output_dir`. Output directory to store converted dataset in arrow format, which can be used for training directly.
* Data output directory: `data_output_dirs`. Directory to store preprocessed output, including three sub-directories:
* `cache`: Directory to store Hugging Face data cache.
* `jsonl`: Output directory to store converted dataset in jsonl format.
* `arrow`: Output directory to store converted dataset in arrow format, which can be used for training directly.
* Max length: `max_length`. Max length of spliced samples. Default value is 4096.
* Number of bins for each category: `num_spliced_dataset_bins`. Number of bins for each category, used for bucket-based training.
@ -392,13 +392,15 @@ Command to convert jsonl dataset to arrow format is similar to the command in [3
python prepare_sft_dataset.py.py \
--data_input_dirs "<JSONL_DIR_1>,<JSONL_DIR_2>,<JSONL_DIR_3>" \
--tokenizer_dir "<TOKENIZER_DIR>" \
--data_cache_dir "jsonl_to_arrow_cache" \
--data_jsonl_output_dir "spliced_tokenized_output_jsonl" \
--data_arrow_output_dir "spliced_tokenized_output_arrow" \
--data_output_dirs "spliced tokenized output" \
--max_length 4096 \
--num_spliced_dataset_bins 10
--num_spliced_dataset_bins 10 \
--llama_version 3
```
Additional CLI arguments:
* LLaMA verison: `llama_version`. Specify the LLaMA version.
#### 4. Command Line Arguments for Training
##### 4.1 Arguments for Pretraining

View File

@ -83,7 +83,7 @@ class Conversation:
}
conv = Conversation(
LLaMA2_Conv = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
roles=("Human", "Assistant"),
@ -93,4 +93,14 @@ conv = Conversation(
seps=["<s>", "</s>"],
)
default_conversation = conv
LLaMA3_Conv = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
roles=("Human", "Assistant"),
messages=[],
offset=0,
sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
seps=["<|begin_of_text|>", "<|end_of_text|>"],
)
default_conversation = LLaMA3_Conv

View File

@ -80,15 +80,19 @@ class DataCollatorForSupervisedDataset(object):
# `List[torch.Tensor]`
batch_input_ids = [
torch.LongTensor(instance["input_ids"][: self.max_length])
if len(instance["input_ids"]) > self.max_length
else torch.LongTensor(instance["input_ids"])
(
torch.LongTensor(instance["input_ids"][: self.max_length])
if len(instance["input_ids"]) > self.max_length
else torch.LongTensor(instance["input_ids"])
)
for instance in instances
]
batch_labels = [
torch.LongTensor(instance["labels"][: self.max_length])
if len(instance["labels"]) > self.max_length
else torch.LongTensor(instance["labels"])
(
torch.LongTensor(instance["labels"][: self.max_length])
if len(instance["labels"]) > self.max_length
else torch.LongTensor(instance["labels"])
)
for instance in instances
]

View File

@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
from datasets import dataset_dict
from torch.utils.data import ConcatDataset, Dataset, IterableDataset
from transformers import AutoTokenizer
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
@ -71,7 +72,7 @@ def supervised_tokenize_pretrain(
def supervised_tokenize_sft(
data_point: Dict[str, str],
tokenizer: LlamaTokenizer,
tokenizer: AutoTokenizer,
conversation_template: Conversation = default_conversation,
ignore_index: int = None,
max_length: int = 4096,

View File

@ -1,7 +1,7 @@
import argparse
import torch
from colossal_llama2.dataset.conversation import default_conversation
from colossal_llama.dataset.conversation import default_conversation
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.logging import get_dist_logger

View File

@ -11,12 +11,12 @@ import os
import time
from multiprocessing import cpu_count
from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
from colossal_llama.dataset.spliced_and_tokenized_dataset import (
ClosedToConstantLengthSplicedDataset,
supervised_tokenize_pretrain,
)
from datasets import dataset_dict, load_dataset
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from transformers import AutoTokenizer
from colossalai.logging import get_dist_logger
@ -35,35 +35,24 @@ def main():
parser.add_argument(
"--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer"
)
parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
parser.add_argument(
"--data_jsonl_output_dir",
type=str,
default="jsonl_output",
help="Output directory of spliced dataset with jsonl format",
)
parser.add_argument(
"--data_arrow_output_dir",
type=str,
default="arrow_output",
help="Output directory of spliced dataset with arrow format",
)
parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence")
parser.add_argument("--data_output_dirs", type=str, default="data_output_dirs", help="Data output directory")
parser.add_argument("--max_length", type=int, default=8192, help="Max length of each spliced tokenized sequence")
parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins")
args = parser.parse_args()
if args.num_spliced_dataset_bins >= 100000:
raise ValueError("Too many spliced divisions, must be smaller than 100000")
assert not os.path.exists(args.data_cache_dir), f"Find existed data cache dir {args.data_cache_dir}"
assert not os.path.exists(
args.data_jsonl_output_dir
), f"Find existed jsonl data output dir {args.data_jsonl_output_dir}"
assert not os.path.exists(
args.data_arrow_output_dir
), f"Find existed arrow data output dir {args.data_arrow_output_dir}"
os.makedirs(args.data_jsonl_output_dir)
os.makedirs(args.data_arrow_output_dir)
args.data_cache_dir = os.path.join(args.data_output_dirs, "cache")
args.data_jsonl_output_dir = os.path.join(args.data_output_dirs, "jsonl")
args.data_arrow_output_dir = os.path.join(args.data_output_dirs, "arrow")
if not os.path.exists(args.data_cache_dir):
os.makedirs(args.data_cache_dir)
if not os.path.exists(args.data_jsonl_output_dir):
os.makedirs(args.data_jsonl_output_dir)
if not os.path.exists(args.data_arrow_output_dir):
os.makedirs(args.data_arrow_output_dir)
# Prepare to all input datasets
input_data_paths = []
@ -86,7 +75,7 @@ def main():
train_splits.append(f"train[{start}%:{end}%]")
# Prepare to the tokenizer.
tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
if tokenizer.pad_token is None:

View File

@ -10,10 +10,10 @@ import math
import os
from multiprocessing import cpu_count
from colossal_llama2.dataset.conversation import default_conversation
from colossal_llama2.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft
from colossal_llama.dataset.conversation import default_conversation
from colossal_llama.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft
from datasets import dataset_dict, load_dataset
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from transformers import AddedToken, AutoTokenizer
from colossalai.logging import get_dist_logger
@ -32,35 +32,25 @@ def main():
parser.add_argument(
"--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer"
)
parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
parser.add_argument(
"--data_jsonl_output_dir",
type=str,
default="jsonl_output",
help="Output directory of spliced dataset with jsonl format",
)
parser.add_argument(
"--data_arrow_output_dir",
type=str,
default="arrow_output",
help="Output directory of spliced dataset with arrow format",
)
parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence")
parser.add_argument("--data_output_dirs", type=str, default="data_output_dirs", help="Data output directory")
parser.add_argument("--max_length", type=int, default=8192, help="Max length of each spliced tokenized sequence")
parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins")
parser.add_argument("--llama_version", type=int, default=3, help="LLaMA version")
args = parser.parse_args()
if args.num_spliced_dataset_bins >= 100000:
raise ValueError("Too many spliced divisions, must be smaller than 100000")
assert not os.path.exists(args.data_cache_dir), f"Find existed data cache dir {args.data_cache_dir}"
assert not os.path.exists(
args.data_jsonl_output_dir
), f"Find existed jsonl data output dir {args.data_jsonl_output_dir}"
assert not os.path.exists(
args.data_arrow_output_dir
), f"Find existed arrow data output dir {args.data_arrow_output_dir}"
os.makedirs(args.data_jsonl_output_dir)
os.makedirs(args.data_arrow_output_dir)
args.data_cache_dir = os.path.join(args.data_output_dirs, "cache")
args.data_jsonl_output_dir = os.path.join(args.data_output_dirs, "jsonl")
args.data_arrow_output_dir = os.path.join(args.data_output_dirs, "arrow")
if not os.path.exists(args.data_cache_dir):
os.makedirs(args.data_cache_dir)
if not os.path.exists(args.data_jsonl_output_dir):
os.makedirs(args.data_jsonl_output_dir)
if not os.path.exists(args.data_arrow_output_dir):
os.makedirs(args.data_arrow_output_dir)
# Prepare to all input datasets
input_data_paths = []
@ -83,11 +73,20 @@ def main():
train_splits.append(f"train[{start}%:{end}%]")
# Prepare to the tokenizer.
tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
# Fix </s> split issue: https://github.com/huggingface/transformers/issues/23833
if args.llama_version == 2:
tokenizer.add_tokens(AddedToken("</s>", normalized=False, special=True), special_tokens=True)
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.unk_token
if tokenizer.unk_token is not None:
tokenizer.pad_token = tokenizer.unk_token
else:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.unk_token = tokenizer.eos_token
list_dataset = load_dataset(
path="json",

View File

@ -1,9 +1,10 @@
torch<2.0.0, >=1.12.1
packaging==23.1
colossalai==0.3.5
torch==2.1.2
huggingface-hub
packaging==24.0
colossalai==0.3.6
autoflake==2.2.1
black==23.9.1
transformers==4.33.3
transformers==4.34.1
tensorboard==2.14.0
six==1.16.0
datasets

View File

@ -1,6 +1,6 @@
import argparse
from colossal_llama2.utils.stream_chat_patch import streaming_chat
from colossal_llama.utils.stream_chat_patch import streaming_chat
from transformers import AutoModelForCausalLM, AutoTokenizer
SYSTEM = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."

View File

@ -12,18 +12,18 @@ from contextlib import nullcontext
import torch
import torch.distributed as dist
from colossal_llama2.dataset.loader import (
from colossal_llama.dataset.loader import (
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
load_tokenized_dataset,
)
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
from colossal_llama.utils.ckpt_io import load_checkpoint, save_checkpoint
from colossal_llama.utils.flash_attention_patch import replace_with_flash_attention
from colossal_llama.utils.froze import freeze_non_embeds_parameters
from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import AutoTokenizer, LlamaForCausalLM
import colossalai
from colossalai.accelerator import get_accelerator
@ -89,7 +89,7 @@ def main() -> None:
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
parser.add_argument("--max_length", type=int, default=8192, help="Model max length")
parser.add_argument(
"--mixed_precision",
type=str,
@ -136,7 +136,7 @@ def main() -> None:
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
accelerator = get_accelerator()
coordinator = DistCoordinator()
@ -196,7 +196,7 @@ def main() -> None:
# ======================================================
# Initialize Tokenizer, Dataset, Collator and Dataloader
# ======================================================
tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
if args.pad_token == "eos":
tokenizer.pad_token = tokenizer.eos_token
elif args.pad_token == "unk":
@ -253,9 +253,11 @@ def main() -> None:
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
optimizer = HybridAdam(
model_params=filter(lambda p: p.requires_grad, model.parameters())
if args.freeze_non_embeds_params
else model.parameters(),
model_params=(
filter(lambda p: p.requires_grad, model.parameters())
if args.freeze_non_embeds_params
else model.parameters()
),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,

View File

@ -0,0 +1 @@
1.0.0

View File

@ -66,7 +66,7 @@ def benchmark_train(args):
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ======================================================

View File

@ -37,7 +37,7 @@ def train(args):
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ==============================

View File

@ -39,7 +39,7 @@ def train(args):
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ======================================================

View File

@ -34,7 +34,7 @@ def train(args):
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ======================================================

View File

@ -29,7 +29,7 @@ def train(args):
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ==============================

View File

@ -81,7 +81,7 @@ def rm_and_merge(
def main(args):
colossalai.launch_from_torch(config={}, seed=42)
colossalai.launch_from_torch(seed=42)
accelerator = get_accelerator()
world_size = dist.get_world_size()

View File

@ -81,7 +81,7 @@ def rm_and_merge(
def main(args):
colossalai.launch_from_torch(config={}, seed=42)
colossalai.launch_from_torch(seed=42)
world_size = dist.get_world_size()
rank = dist.get_rank()

Binary file not shown.

View File

@ -57,7 +57,7 @@ def main():
args = parse_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
colossalai.launch_from_torch(seed=args.seed)
coordinator = DistCoordinator()
config = MixtralConfig.from_pretrained(args.model_name)
@ -96,7 +96,11 @@ def main():
if coordinator.rank == 0:
text = ["Hello my name is"]
else:
text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"]
text = [
"What's the largest country in the world?",
"How many people live in China?",
"帮我续写这首诗:离离原上草",
]
tokenizer.pad_token = tokenizer.unk_token
inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())

View File

@ -50,7 +50,7 @@ def check_mixtral_moe_layer():
def run_dist(rank: int, world_size: int, port: int):
colossalai.launch({}, rank, world_size, "localhost", port)
colossalai.launch(rank, world_size, "localhost", port)
check_mixtral_moe_layer()

View File

@ -133,7 +133,7 @@ def check_mixtral_moe_layer():
def run_dist(rank: int, world_size: int, port: int):
colossalai.launch({}, rank, world_size, "localhost", port)
colossalai.launch(rank, world_size, "localhost", port)
check_mixtral_moe_layer()

View File

@ -145,7 +145,7 @@ def main():
args = parse_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
colossalai.launch_from_torch(seed=args.seed)
coordinator = DistCoordinator()
# Set plugin
@ -195,9 +195,9 @@ def main():
lr_scheduler = CosineAnnealingWarmupLR(
optimizer=optimizer,
total_steps=args.num_epochs * len(dataloader),
warmup_steps=args.warmup_steps
if args.warmup_steps is not None
else int(args.num_epochs * len(dataloader) * 0.025),
warmup_steps=(
args.warmup_steps if args.warmup_steps is not None else int(args.num_epochs * len(dataloader) * 0.025)
),
eta_min=0.1 * args.lr,
)

View File

@ -5,7 +5,7 @@ This directory contains the applications that are powered by Colossal-AI.
The list of applications include:
- [X] [Open-Sora](https://github.com/hpcaitech/Open-Sora): Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models
- [X] [Colossal-LLaMA-2](./Colossal-LLaMA-2/): Continual Pre-training of LLaMA-2.
- [X] [Colossal-LLaMA](./Colossal-LLaMA/): Continual Pre-training and Supervisied Fine-tuning of LLaMA2 / LLaMA3.
- [X] [ColossalEval](./ColossalEval): Evaluation Pipeline for LLMs.
- [X] [ColossalChat](./Chat/README.md): Replication of ChatGPT with RLHF.
- [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters.

View File

@ -246,7 +246,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
@compatibility(is_backward_compatible=True)
class ActivationCheckpointCodeGen(CodeGen):
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, verbose=None) -> PythonCode:
free_vars: List[str] = []
body: List[str] = []
globals_: Dict[str, Any] = {}

View File

@ -126,7 +126,7 @@ class AMPOptimizer(OptimizerWrapper):
return self.grad_scaler.scale.item()
def zero_grad(self, *args, **kwargs):
self.module.overflow_counter = torch.cuda.IntTensor([0])
self.module.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
return self.optim.zero_grad(set_to_none=True)
def step(self, *args, **kwargs):

View File

@ -4,8 +4,8 @@ from typing import Optional, Set
import torch
import torch.nn as nn
from colossalai.utils import _cast_float
from colossalai.zero.legacy.gemini.tensor_utils import free_storage
from colossalai.utils import _cast_float, get_current_device
from colossalai.utils.common import free_storage
from .region_manager import RegionManager
from .util import GlobalRuntimeInfo
@ -25,7 +25,7 @@ class BaseOffloadModule:
self.model = model
self.region_manager = region_manager
self.grad_hook_list = []
self.overflow_counter = torch.cuda.IntTensor([0])
self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_current_device())
self.grad_offload_stream = torch.cuda.current_stream() if is_sync else GlobalRuntimeInfo.d2h_stream

View File

@ -3,7 +3,8 @@ from typing import Dict, List, Tuple
import torch
from torch.fx import Node
from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage
from colossalai.utils.common import free_storage
from colossalai.zero.gemini.chunk.chunk import alloc_storage
class Region:

View File

@ -372,7 +372,7 @@ if AUTOCHUNK_AVAILABLE:
if print_progress:
get_logger().info("AutoChunk start codegen")
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, verbose=None) -> PythonCode:
free_vars: List[str] = []
body: List[str] = []
globals_: Dict[str, Any] = {}

View File

@ -8,9 +8,18 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
SUPPORT_PEFT = False
try:
import peft
SUPPORT_PEFT = True
except ImportError:
pass
import colossalai.interface.pretrained as pretrained_utils
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig
from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
@ -221,6 +230,56 @@ class Booster:
assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
return self.plugin.no_sync(model, optimizer)
def enable_lora(
self,
model: nn.Module,
pretrained_dir: Optional[str] = None,
lora_config: "peft.LoraConfig" = None,
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
quantize=False,
) -> nn.Module:
"""
Wrap the passed in model with LoRA modules for training. If pretrained directory is provided, lora configs and weights are loaded from that directory.
Lora in ColossalAI is implemented using Huggingface peft library, so the arguments for Lora configuration are same as those of peft.
Args:
model (nn.Module): The model to be appended with LoRA modules.
pretrained_dir(str, optional): The path to the pretrained directory, can be a local directory
or model_id of a PEFT configuration hosted inside a model repo on the Hugging Face Hub.
When set to None, create new lora configs and weights for the model using the passed in lora_config. Defaults to None.
lora_config: (peft.LoraConfig, optional): Passed in LoraConfig for peft. Defaults to None.
"""
if not SUPPORT_PEFT:
raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!")
assert self.plugin is not None, f"Lora can only be enabled when a plugin is provided."
assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora."
if pretrained_dir is None:
assert (
lora_config is not None
), "Please provide configuration for Lora when pretrained directory path isn't passed in."
assert isinstance(
lora_config, peft.LoraConfig
), "The passed in configuration should be an instance of peft.LoraConfig."
if lora_config is None:
assert (
pretrained_dir is not None
), "Please provide pretrained directory path if not passing in lora configuration."
if quantize is True:
if bnb_quantization_config is not None:
warnings.warn(
"User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
)
else:
bnb_quantization_config = BnbQuantizationConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
"""Load model from checkpoint.
@ -323,3 +382,20 @@ class Booster:
checkpoint (str): Path to the checkpoint. It must be a local file path.
"""
self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
def save_lora_as_pretrained(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
) -> None:
"""
Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
Args:
model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
checkpoint (str): Path to the checkpoint directory. It must be a local path.
use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
"""
if not SUPPORT_PEFT:
raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!")
assert self.plugin is not None, f"Lora can only be enabled when a plugin is provided."
assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora."
self.checkpoint_io.save_lora_as_pretrained(model, checkpoint, use_safetensors)

View File

@ -3,7 +3,7 @@ import logging
import os
import random
from pathlib import Path
from typing import Callable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterator, List, Optional, Tuple
import numpy as np
import torch
@ -44,10 +44,10 @@ ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
# 1. A mapping from integer param_id to param32 shape.
if optim is None:
return {}
param_info = {"id2shape": {}}
start_index = 0
for group in optim.param_groups:
for param_id, param in enumerate(group["params"], start_index):
@ -424,6 +424,7 @@ class GeminiPlugin(DPPluginBase):
)
self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None
self.dp_size = self.zero_size * self.extra_dp_size
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
@ -443,6 +444,9 @@ class GeminiPlugin(DPPluginBase):
def support_no_sync(self) -> bool:
return False
def support_lora(self) -> bool:
return False
def control_precision(self) -> bool:
return True
@ -527,7 +531,7 @@ class GeminiPlugin(DPPluginBase):
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
optimizer_params_info = get_param_info(optimizer)
params_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
# convert model to sync bn
# FIXME(ver217): gemini does not support sync bn
@ -558,7 +562,7 @@ class GeminiPlugin(DPPluginBase):
**self.zero_optim_config,
**self.optim_kwargs,
tp_group=self.tp_group,
optimizer_params_info=optimizer_params_info,
params_info=params_info,
verbose=self.verbose,
)
@ -572,3 +576,8 @@ class GeminiPlugin(DPPluginBase):
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError
def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> nn.Module:
raise NotImplementedError

View File

@ -4,7 +4,7 @@ import warnings
from contextlib import contextmanager
from functools import partial
from types import MethodType
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union
import numpy as np
import torch
@ -34,7 +34,6 @@ from colossalai.zero.low_level import LowLevelZeroOptimizer
from .pp_plugin_base import PipelinePluginBase
DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
@ -213,12 +212,7 @@ def get_param_info(optim: Optimizer):
if optim is None:
return {}
param_info = {
"param_groups": [],
"param2id": {},
"id2param": {},
"param2shape": {},
}
param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
start_index = 0
for group in optim.param_groups:
packed_group = {k: v for k, v in group.items() if k != "params"}
@ -947,6 +941,8 @@ class HybridParallelPlugin(PipelinePluginBase):
num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64.
"""
def __init__(
@ -987,8 +983,11 @@ class HybridParallelPlugin(PipelinePluginBase):
custom_policy: Policy = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
num_layers_per_stage: Optional[List[int]] = None,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
make_vocab_size_divisible_by: int = 64,
dp_outside: bool = True,
) -> None:
super().__init__()
assert (
@ -1036,7 +1035,12 @@ class HybridParallelPlugin(PipelinePluginBase):
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
if dp_outside:
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
else:
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
@ -1050,9 +1054,10 @@ class HybridParallelPlugin(PipelinePluginBase):
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(
self.pg_mesh,
pipeline_axis=PP_AXIS,
pipeline_axis=self.pp_axis,
enable_interleave=pp_style == "interleaved",
num_model_chunks=num_model_chunks,
num_layers_per_stage=num_layers_per_stage,
)
if pp_style == "interleaved":
@ -1074,13 +1079,13 @@ class HybridParallelPlugin(PipelinePluginBase):
else:
raise NotImplementedError()
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
else:
self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
@ -1095,6 +1100,7 @@ class HybridParallelPlugin(PipelinePluginBase):
sequence_parallelism_mode=sequence_parallelism_mode,
enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
)
self.amp_config = dict(
@ -1150,6 +1156,9 @@ class HybridParallelPlugin(PipelinePluginBase):
def support_no_sync(self) -> bool:
return True
def support_lora(self) -> bool:
return False
def control_checkpoint_io(self) -> bool:
return True
@ -1170,7 +1179,7 @@ class HybridParallelPlugin(PipelinePluginBase):
and self.sequence_parallelism_mode == "all_to_all"
)
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS])
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
else:
dp_group = self.dp_group
model = HybridParallelModule(
@ -1318,7 +1327,10 @@ class HybridParallelPlugin(PipelinePluginBase):
_kwargs = kwargs.copy()
distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
sampler = distributed_sampler_cls(
dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
dataset,
num_replicas=self.pg_mesh.size(self.dp_axis),
rank=self.pg_mesh.coordinate(self.dp_axis),
shuffle=shuffle,
)
# Deterministic dataloader
@ -1347,3 +1359,8 @@ class HybridParallelPlugin(PipelinePluginBase):
self.zero_stage != 2
), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed."
return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
def enable_lora(
self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> Module:
raise NotImplementedError

View File

@ -1,12 +1,15 @@
import enum
import logging
import os
import warnings
from functools import partial
from pathlib import Path
from types import MethodType
from typing import Callable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterator, List, Optional, Tuple
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
@ -25,6 +28,7 @@ from colossalai.checkpoint_io.utils import (
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.zero import LowLevelZeroOptimizer
from .dp_plugin_base import DPPluginBase
@ -42,6 +46,12 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]
class OptimizerParamCheckState(enum.Enum):
ORIGIN_PARAM_FINDED = 0
ORIGIN_PARAM_NOT_FIND = -1
LORA_PARM_EXISTED = -2
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(self, module: nn.Module, precision: str) -> None:
super().__init__(module)
@ -209,6 +219,19 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
from peft import PeftModel
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
peft_model = model.unwrap()
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
class LowLevelZeroPlugin(DPPluginBase):
"""
@ -288,6 +311,7 @@ class LowLevelZeroPlugin(DPPluginBase):
cpu_offload=cpu_offload,
master_weights=master_weights,
)
self.lora_enabled = False
self.verbose = verbose
# set class name with stage, for better error message
@ -296,6 +320,9 @@ class LowLevelZeroPlugin(DPPluginBase):
def support_no_sync(self) -> bool:
return self.stage == 1
def support_lora(self) -> bool:
return False
def control_precision(self) -> bool:
return True
@ -308,6 +335,79 @@ class LowLevelZeroPlugin(DPPluginBase):
def supported_devices(self) -> List[str]:
return ["cuda", "npu"]
def support_lora(self) -> bool:
return True
def enable_lora(
self,
model: nn.Module,
pretrained_dir: Optional[str] = None,
lora_config: Optional[Dict] = None,
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
) -> nn.Module:
from peft import PeftModel, get_peft_model
assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
self.lora_enabled = True
warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
if bnb_quantization_config is not None:
model = quantize_model(model, bnb_quantization_config)
if pretrained_dir is None:
peft_model = get_peft_model(model, lora_config)
else:
peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
return peft_model
def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
origin_param_id = id(origin_param)
for group_id, param_group in enumerate(optimizer.param_groups):
for p in param_group["params"]:
if id(p) == origin_param_id:
return group_id
return -1
def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
origin_param_id = id(origin_param)
lora_param_id = id(lora_param)
target_group_id = None
for group_id, param_group in enumerate(optimizer.param_groups):
for p in param_group["params"]:
if id(p) == lora_param_id:
# check if the lora parameter exists.
return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
if id(p) == origin_param_id:
target_group_id = group_id
if target_group_id is not None:
return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED
else:
return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND
def add_lora_params_to_optimizer(self, model, optimizer):
"""add lora parameters to optimizer"""
name2param = {}
for name, param in model.named_parameters():
name2param[name] = param
for name, param in name2param.items():
if "lora_A" in name or "lora_B" in name:
origin_key = name.replace("lora_A.", "")
origin_key = origin_key.replace("lora_B.", "")
origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer")
origin_param = name2param[origin_key]
group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
warnings.warn(
"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
)
elif (
check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
and group_id is not None
and group_id >= 0
):
optimizer.param_groups[group_id]["params"].append(param)
def configure(
self,
model: nn.Module,
@ -316,6 +416,15 @@ class LowLevelZeroPlugin(DPPluginBase):
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if self.lora_enabled:
from peft import PeftModel
assert isinstance(
model, PeftModel
), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
if optimizer is not None:
self.add_lora_params_to_optimizer(model, optimizer)
if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.precision)

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterator, List, Optional, Tuple
import torch.nn as nn
from torch.optim import Optimizer
@ -33,6 +33,10 @@ class Plugin(ABC):
def support_no_sync(self) -> bool:
pass
@abstractmethod
def support_lora(self) -> bool:
pass
@abstractmethod
def configure(
self,
@ -63,6 +67,12 @@ class Plugin(ABC):
Context manager to disable gradient synchronization.
"""
@abstractmethod
def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
"""
Add LoRA modules to the model passed in. Should only be called in booster.enable_lora().
"""
@abstractmethod
def prepare_dataloader(
self,

View File

@ -1,4 +1,4 @@
from typing import Callable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
@ -9,6 +9,8 @@ from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.utils import get_current_device
from .dp_plugin_base import DPPluginBase
@ -116,6 +118,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
def save_lora_as_pretrained(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
) -> None:
"""
Save the lora adapters and adapter configuration file to checkpoint directory.
"""
from peft import PeftModel
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
peft_model = model.unwrap()
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
class TorchDDPModel(ModelWrapper):
def __init__(self, module: nn.Module, *args, **kwargs) -> None:
@ -173,6 +191,9 @@ class TorchDDPPlugin(DPPluginBase):
def support_no_sync(self) -> bool:
return True
def support_lora(self) -> bool:
return True
def control_precision(self) -> bool:
return False
@ -183,7 +204,7 @@ class TorchDDPPlugin(DPPluginBase):
return True
def supported_devices(self) -> List[str]:
return ["cuda"]
return ["cuda", "npu"]
def configure(
self,
@ -194,7 +215,7 @@ class TorchDDPPlugin(DPPluginBase):
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
# cast model to cuda
model = model.cuda()
model = model.to(get_current_device())
# convert model to sync bn
model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
@ -216,3 +237,21 @@ class TorchDDPPlugin(DPPluginBase):
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin."
return model.module.no_sync()
def enable_lora(
self,
model: nn.Module,
pretrained_dir: Optional[str] = None,
lora_config: Optional[Dict] = None,
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
) -> nn.Module:
from peft import PeftModel, get_peft_model
if bnb_quantization_config is not None:
model = quantize_model(model, bnb_quantization_config)
assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model."
if pretrained_dir is None:
return get_peft_model(model, lora_config)
else:
return PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)

View File

@ -2,7 +2,7 @@ import logging
import os
import warnings
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
import torch
import torch.nn as nn
@ -318,6 +318,9 @@ class TorchFSDPPlugin(DPPluginBase):
def support_no_sync(self) -> bool:
return False
def support_lora(self) -> bool:
return False
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
@ -361,3 +364,8 @@ class TorchFSDPPlugin(DPPluginBase):
def get_checkpoint_io(self) -> CheckpointIO:
return TorchFSDPCheckpointIO()
def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> nn.Module:
raise NotImplementedError

View File

@ -335,3 +335,20 @@ class CheckpointIO(ABC):
"""
state_dict = torch.load(checkpoint)
lr_scheduler.load_state_dict(state_dict)
# ================================================================================
# Abstract method for lora saving implementation.
# ================================================================================
@abstractmethod
def save_lora_as_pretrained(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
) -> None:
"""
Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
Args:
model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
checkpoint (str): Path to the checkpoint directory. It must be a local path.
use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
"""

View File

@ -228,3 +228,6 @@ class GeneralCheckpointIO(CheckpointIO):
self.__class__.__name__, "\n\t".join(error_msgs)
)
)
def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
raise NotImplementedError

View File

@ -14,6 +14,12 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.padded_tensor import (
init_as_padded_tensor,
is_padded_tensor,
to_padded_tensor,
to_unpadded_tensor,
)
from colossalai.utils import get_current_device
from .general_checkpoint_io import GeneralCheckpointIO
@ -32,6 +38,7 @@ from .utils import (
save_param_groups,
save_state_dict,
save_state_dict_shards,
search_padding_dim,
search_tp_partition_dim,
sharded_optimizer_loading_epilogue,
)
@ -89,6 +96,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if param is None:
continue
# Gather tensor pieces when using tensor parallel.
if is_padded_tensor(param):
param = to_unpadded_tensor(param)
param_ = gather_distributed_param(param, keep_vars=False)
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None:
@ -231,7 +240,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
final_index_file_path = copy.deepcopy(save_index_file)
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
@ -251,6 +259,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
use_safetensors=use_safetensors,
use_pp_format=True,
)
if control_saving:
assert (
self.dp_rank == 0 and self.tp_rank == 0
@ -867,6 +876,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
dist.all_gather(gather_tensor, v, group=tp_group)
v = torch.cat(gather_tensor, dim=partition_dim)
padding_dim = search_padding_dim(v.shape, original_shape)
if padding_dim is not None:
v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
v = to_unpadded_tensor(v)
state_[k] = v.detach().clone().to(device)
return state_
@ -899,6 +913,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if isinstance(v, torch.Tensor) and k != "step":
# Shard state along tensor parallel group.
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
global_shape = current_shape
if partition_dim is not None:
# pad embedding params
global_shape = (
*current_shape[:partition_dim],
current_shape[partition_dim] * self.tp_size,
*current_shape[partition_dim + 1 :],
)
padding_dim = search_padding_dim(global_shape, original_shape)
if padding_dim is not None:
v = to_padded_tensor(v, global_shape[padding_dim], padding_dim)
if partition_dim is not None:
slice_size = current_shape[partition_dim]
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]

View File

@ -120,6 +120,15 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
return partition_dim
def search_padding_dim(global_shape: torch.Size, original_shape: torch.Size) -> Optional[int]:
padding_dim = None
for dim, length in enumerate(global_shape):
if length > original_shape[dim]:
padding_dim = dim
break
return padding_dim
# ======================================
# Helper classes and functions for saving shard file
# ======================================

View File

@ -1,22 +1,27 @@
import threading
class SingletonMeta(type):
"""
The Singleton class can be implemented in different ways in Python. Some
possible methods include: base class, decorator, metaclass. We will use the
metaclass because it is best suited for this purpose.
Thread-safe Singleton Meta with double-checked locking.
Reference: https://en.wikipedia.org/wiki/Double-checked_locking
"""
_instances = {}
_lock = threading.Lock()
def __call__(cls, *args, **kwargs):
"""
Possible changes to the value of the `__init__` argument do not affect
the returned instance.
"""
# First check (without locking) for performance reasons
if cls not in cls._instances:
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
# Acquire a lock before proceeding to the second check
with cls._lock:
# Second check with lock held to ensure thread safety
if cls not in cls._instances:
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
else:
assert (
len(args) == 0 and len(kwargs) == 0
), f"{cls.__name__} is a singleton class and a instance has been created."
), f"{cls.__name__} is a singleton class and an instance has been created."
return cls._instances[cls]

View File

@ -625,7 +625,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if CODEGEN_AVAILABLE:
class ActivationCheckpointCodeGen(CodeGen):
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, verbose=None) -> PythonCode:
free_vars: List[str] = []
body: List[str] = []
globals_: Dict[str, Any] = {}

View File

@ -270,7 +270,7 @@ def llama_rmsnorm_forward(
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
class NopadLlamaMLP(ParallelModule, LlamaMLP):
class NopadLlamaMLP(LlamaMLP, ParallelModule):
def __init__(
self,
config: LlamaConfig,
@ -392,7 +392,7 @@ class NopadLlamaMLP(ParallelModule, LlamaMLP):
return f"gate_up_proj MergedLinear1D_Col: in_features={self.gate_up_weight.shape[1]}x2, out_features={self.gate_up_weight.shape[2]}, bias=False"
class NopadLlamaAttention(ParallelModule, LlamaAttention):
class NopadLlamaAttention(LlamaAttention, ParallelModule):
def __init__(
self,
config: LlamaConfig,

View File

@ -2,20 +2,15 @@
# -*- encoding: utf-8 -*-
import os
import warnings
from pathlib import Path
from typing import Dict, Union
import torch.distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.context import Config
from colossalai.logging import get_dist_logger
from colossalai.utils import set_seed
def launch(
config: Union[str, Path, Config, Dict],
rank: int,
world_size: int,
host: str,
@ -44,8 +39,6 @@ def launch(
Raises:
Exception: Raise exception when config type is wrong
"""
if rank == 0:
warnings.warn("`config` is deprecated and will be removed soon.")
cur_accelerator = get_accelerator()
@ -68,7 +61,6 @@ def launch(
def launch_from_slurm(
config: Union[str, Path, Config, Dict],
host: str,
port: int,
backend: str = "nccl",
@ -95,7 +87,6 @@ def launch_from_slurm(
)
launch(
config=config,
rank=rank,
world_size=world_size,
host=host,
@ -107,7 +98,6 @@ def launch_from_slurm(
def launch_from_openmpi(
config: Union[str, Path, Config, Dict],
host: str,
port: int,
backend: str = "nccl",
@ -135,7 +125,6 @@ def launch_from_openmpi(
)
launch(
config=config,
local_rank=local_rank,
rank=rank,
world_size=world_size,
@ -147,9 +136,7 @@ def launch_from_openmpi(
)
def launch_from_torch(
config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024, verbose: bool = True
):
def launch_from_torch(backend: str = "nccl", seed: int = 1024, verbose: bool = True):
"""A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
from the environment variables set by PyTorch
@ -171,7 +158,6 @@ def launch_from_torch(
)
launch(
config=config,
local_rank=local_rank,
rank=rank,
world_size=world_size,

View File

@ -119,10 +119,6 @@ class FlashAttentionLoader(KernelLoader):
]
class FlashAttentionWithPaddingMaskLoader(KernelLoader):
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension]
class FlashAttentionWithCustomMaskLoader(KernelLoader):
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]

View File

@ -56,7 +56,7 @@ class Worker:
# initialize a ray collective group, otherwise colossalai distributed env won't be built successfully
collective.init_collective_group(world_size, rank, "nccl", "default")
# initialize and set distributed environment
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..")
log_cuda_info("Worker.setup")

View File

@ -42,7 +42,7 @@ class CaiInferEngine:
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
model = LlamaForCausalLM.from_pretrained("your_path_to_model")
tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")

View File

@ -36,7 +36,7 @@ from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
model = LlamaForCausalLM.from_pretrained("/path/to/model")
tokenizer = LlamaTokenizer.from_pretrained("/path/to/model")
@ -57,27 +57,27 @@ We conducted multiple benchmark tests to evaluate the performance. We compared t
### Llama Throughput (tokens/s) | input length=1024, output length=128
#### A10 7b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)|
| :---: | :---: | :---: | :---: | :---: | :---: | :---:|
| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM |
| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM |
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16) |
|:----------------------------:|:-----:|:-----:|:------:|:------:|:------:|:------:|
| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM |
| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM | OOM |
#### A10 13b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
| :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(4) |
|:----------------------------:|:-----:|:-----:|:-----:|:-----:|
| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |
#### A800 7b, fp16
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
|:----------------------------:|:-----:|:------:|:------:|:------:|:------:|
| Pipeline Inference | 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |
#### A800 13b, fp16
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference | 41.78 | 94.18 | 172.67| 310.75| 470.15 |
| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 |
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
|:----------------------------:|:-----:|:-----:|:------:|:------:|:------:|
| Pipeline Inference | 41.78 | 94.18 | 172.67 | 310.75 | 470.15 |
| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 |

View File

@ -12,7 +12,7 @@ from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
def data_gen(batch_size: int = 4, seq_len: int = 512):

View File

@ -56,7 +56,7 @@ class Worker:
# initialize a ray collective group, otherwise colossalai distributed env won't be built successfully
collective.init_collective_group(world_size, rank, "nccl", "default")
# initialize and set distributed environment
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..")
log_cuda_info("Worker.setup")

View File

@ -98,7 +98,7 @@ class ColossalInferenceHandler(BaseHandler, ABC):
self.model.cuda()
self.model.eval()
colossalai.launch(config={}, rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
colossalai.launch(rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
logger.info("Initializing TPInferEngine ...")
shard_config = ShardConfig(
enable_tensor_parallelism=True if self.tp_size > 1 else False, extra_kwargs={"inference_only": True}

View File

@ -114,7 +114,7 @@ def run_worker(rank, args, master_func):
port = args.master_port
backend = "nccl" if device == "cuda" else "gloo"
launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
launch(rank, world_size, host, int(port), backend, verbose=False)
ppg.set_global_info(
rank=rank,
world_size=world_size,

View File

@ -8,7 +8,7 @@ Licensed under the MIT License.
"""
import torch
from colossalai.utils import multi_tensor_applier
from colossalai.utils import get_current_device, multi_tensor_applier
class FusedAdam(torch.optim.Optimizer):
@ -75,7 +75,7 @@ class FusedAdam(torch.optim.Optimizer):
fused_optim = FusedOptimizerLoader().load()
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=get_current_device())
self.multi_tensor_adam = fused_optim.multi_tensor_adam
else:
raise RuntimeError("FusedAdam requires cuda extensions")

View File

@ -3,7 +3,7 @@ from typing import Any, Optional
import torch
from colossalai.kernel.kernel_loader import FusedOptimizerLoader
from colossalai.utils import multi_tensor_applier
from colossalai.utils import get_current_device, multi_tensor_applier
from .cpu_adam import CPUAdam
@ -87,7 +87,7 @@ class HybridAdam(CPUAdam):
if torch.cuda.is_available():
fused_optim = FusedOptimizerLoader().load()
self.gpu_adam_op = fused_optim.multi_tensor_adam
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=get_current_device())
@torch.no_grad()
def step(self, closure=None, div_scale: float = -1):

View File

@ -45,6 +45,18 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
return unpickle
def check_for_nccl_backend(group):
pg = group or c10d._get_default_group()
# Gate PG wrapper check on Gloo availability.
if c10d._GLOO_AVAILABLE:
# It is not expected for PG to be wrapped many times, but support it just
# in case
while isinstance(pg, c10d._ProcessGroupWrapper):
pg = pg.wrapped_pg
return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
# NOTE: FIXME: NPU DOES NOT support isend nor irecv, so broadcast is kept for future use
def _broadcast_object_list(
object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None

View File

@ -7,7 +7,7 @@ from torch.nn import Module
from torch.utils._pytree import tree_map
from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils import get_current_device
@ -327,7 +327,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.send_forward(output_obj)
if outputs is not None:
outputs = merge_batch(outputs)
if isinstance(model, ModelWrapper):
model = model.unwrap()
batch_size_dim = getattr(model, "batch_size_dim", 0)
outputs = merge_batch(outputs, batch_size_dim)
return {"loss": accum_loss, "outputs": outputs}
def run_forward_backward(
@ -410,7 +413,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
if outputs is not None:
outputs = merge_batch(outputs)
if isinstance(model, ModelWrapper):
model = model.unwrap()
batch_size_dim = getattr(model, "batch_size_dim", 0)
outputs = merge_batch(outputs, batch_size_dim)
return {"loss": accum_loss, "outputs": outputs}
def forward_backward_step(

View File

@ -27,16 +27,18 @@ class PipelineStageManager:
pipeline_axis: int,
enable_interleave: bool = False,
num_model_chunks: int = 1,
num_layers_per_stage: Optional[List[int]] = None,
) -> None:
assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"
self.num_layers_per_stage = None
self.pg_mesh = pg_mesh
self.pipeline_axis = pipeline_axis
self.prev_rank: Optional[Tuple[int, ...]] = None
self.next_rank: Optional[Tuple[int, ...]] = None
self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
if num_layers_per_stage is not None:
assert len(num_layers_per_stage) == self.num_stages
self.num_layers_per_stage = num_layers_per_stage
# init prev and next coord
coord = self.pg_mesh.coordinate()
@ -56,6 +58,8 @@ class PipelineStageManager:
self.p2p_groups[tuple(ranks_in_group)] = group
self.is_interleave = enable_interleave
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
self.num_model_chunks: int = num_model_chunks
if enable_interleave:
# use circle p2p communication
# add the process group of the first rank and the last rank
@ -64,59 +68,11 @@ class PipelineStageManager:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
self.num_model_chunks: int = num_model_chunks
# for shardformer, hold stage indices of model
self.stage_indices: List[Tuple[int, int]]
# for shardformer, hold model chunk id
self.model_chunk_id: Optional[int] = None
@property
def control_distribute_layers(self) -> bool:
return self.num_layers_per_stage is not None
def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None:
"""Set the distribution configuration.
This allows user to customize the number of layers for each stage.
Args:
num_model_layers (int): Number of layers in the model.
num_layers_per_stage (List[int]): Number of layers for each stage.
"""
assert all([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage])
assert sum(num_layers_per_stage) == num_model_layers
assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1)
self.num_model_layers = num_model_layers
self.num_layers_per_stage = num_layers_per_stage
def distribute_layers(
self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
) -> List[int]:
"""Divide layers into stages"""
num_stages = self.num_stages if num_stages is None else num_stages
num_model_chunks = (
(self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
)
if self.control_distribute_layers:
assert num_layers == self.num_model_layers
return self.num_layers_per_stage
else:
quotient = num_layers // (num_stages * num_model_chunks)
remainder = num_layers % (num_stages * num_model_chunks)
# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages * num_model_chunks
# deal with the rest layers
if remainder > 0:
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage
def get_stage_index(
self,
layers_per_stage: List[int],
@ -139,9 +95,7 @@ class PipelineStageManager:
"""
stage = self.stage if stage is None else stage
num_model_chunks = (
(self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
)
num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks
num_stages = self.num_stages if num_stages is None else num_stages
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
@ -261,3 +215,25 @@ class PipelineStageManager:
self.model_chunk_id = model_chunk_id
yield
self.model_chunk_id = old_model_chunk_id
def distribute_layers(
self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
) -> List[int]:
if self.num_layers_per_stage is not None:
assert sum(self.num_layers_per_stage) == num_layers
return self.num_layers_per_stage
num_stages = self.num_stages if num_stages is None else num_stages
num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks
quotient = num_layers // (num_stages * num_model_chunks)
remainder = num_layers % (num_stages * num_model_chunks)
# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages * num_model_chunks
# deal with the rest layers
if remainder > 0:
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage

View File

@ -0,0 +1,7 @@
from .bnb import quantize_model
from .bnb_config import BnbQuantizationConfig
__all__ = [
"BnbQuantizationConfig",
"quantize_model",
]

View File

@ -0,0 +1,321 @@
# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py
import logging
import torch
import torch.nn as nn
from .bnb_config import BnbQuantizationConfig
try:
import bitsandbytes as bnb
IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0"
IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2"
except ImportError:
pass
logger = logging.getLogger(__name__)
def quantize_model(
model: torch.nn.Module,
bnb_quantization_config: BnbQuantizationConfig,
):
"""
This function will quantize the input loaded model with the associated config passed in `bnb_quantization_config`.
We will quantize the model and put the model on the GPU.
Args:
model (`torch.nn.Module`):
Input model. The model already loaded
bnb_quantization_config (`BnbQuantizationConfig`):
The bitsandbytes quantization parameters
Returns:
`torch.nn.Module`: The quantized model
"""
load_in_4bit = bnb_quantization_config.load_in_4bit
load_in_8bit = bnb_quantization_config.load_in_8bit
if load_in_8bit and not IS_8BIT_BNB_AVAILABLE:
raise ImportError(
"You have a version of `bitsandbytes` that is not compatible with 8bit quantization,"
" make sure you have the latest version of `bitsandbytes` installed."
)
if load_in_4bit and not IS_4BIT_BNB_AVAILABLE:
raise ValueError(
"You have a version of `bitsandbytes` that is not compatible with 4bit quantization,"
"make sure you have the latest version of `bitsandbytes` installed."
)
# We keep some modules such as the lm_head in their original dtype for numerical stability reasons
if bnb_quantization_config.skip_modules is None:
bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)
modules_to_not_convert = bnb_quantization_config.skip_modules
# We add the modules we want to keep in full precision
if bnb_quantization_config.keep_in_fp32_modules is None:
bnb_quantization_config.keep_in_fp32_modules = []
keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules
# compatibility with peft
model.is_loaded_in_4bit = load_in_4bit
model.is_loaded_in_8bit = load_in_8bit
# assert model_device is cuda
model_device = next(model.parameters()).device
model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)
# convert param to the right dtype
dtype = bnb_quantization_config.torch_dtype
for name, param in model.state_dict().items():
if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
param.to(torch.float32)
if param.dtype != torch.float32:
name = name.replace(".weight", "").replace(".bias", "")
param = getattr(model, name, None)
if param is not None:
param.to(torch.float32)
elif torch.is_floating_point(param):
param.to(dtype)
if model_device.type == "cuda":
# move everything to cpu in the first place because we can't do quantization if the weights are already on cuda
model.cuda(torch.cuda.current_device())
torch.cuda.empty_cache()
elif torch.cuda.is_available():
model.to(torch.cuda.current_device())
logger.info(
f"The model device type is {model_device.type}. However, cuda is needed for quantization."
"We move the model to cuda."
)
else:
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
return model
def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None):
"""
A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules or by `bnb.nn.Linear4bit`
modules from the `bitsandbytes`library. The function will be run recursively and replace `torch.nn.Linear` modules.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
modules_to_not_convert (`List[str]`):
Names of the modules to not quantize convert. In practice we keep the `lm_head` in full precision for
numerical stability reasons.
current_key_name (`List[str]`, *optional*):
An array to track the current key of the recursion. This is used to check whether the current key (part of
it) is not in the list of modules to not convert.
"""
if modules_to_not_convert is None:
modules_to_not_convert = []
model, has_been_replaced = _replace_with_bnb_layers(
model, bnb_quantization_config, modules_to_not_convert, current_key_name
)
if not has_been_replaced:
logger.warning(
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
" this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)
return model
def _replace_with_bnb_layers(
model,
bnb_quantization_config,
modules_to_not_convert=None,
current_key_name=None,
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
"""
# bitsandbytes will initialize CUDA on import, so it needs to be imported lazily
has_been_replaced = False
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
current_key_name_str = ".".join(current_key_name)
proceed = True
for key in modules_to_not_convert:
if (
(key in current_key_name_str) and (key + "." in current_key_name_str)
) or key == current_key_name_str:
proceed = False
break
if proceed:
# Load bnb module with empty weight and replace ``nn.Linear` module
if bnb_quantization_config.load_in_8bit:
bnb_module = bnb.nn.Linear8bitLt(
module.in_features,
module.out_features,
module.bias is not None,
has_fp16_weights=False,
threshold=bnb_quantization_config.llm_int8_threshold,
)
elif bnb_quantization_config.load_in_4bit:
bnb_module = bnb.nn.Linear4bit(
module.in_features,
module.out_features,
module.bias is not None,
bnb_quantization_config.bnb_4bit_compute_dtype,
compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,
quant_type=bnb_quantization_config.bnb_4bit_quant_type,
)
else:
raise ValueError("load_in_8bit and load_in_4bit can't be both False")
bnb_module.weight.data = module.weight.data
bnb_module.weight.skip_zero_check = True
if module.bias is not None:
bnb_module.bias.data = module.bias.data
bnb_module.bias.skip_zero_check = True
bnb_module.requires_grad_(False)
setattr(model, name, bnb_module)
has_been_replaced = True
if len(list(module.children())) > 0:
_, _has_been_replaced = _replace_with_bnb_layers(
module, bnb_quantization_config, modules_to_not_convert, current_key_name
)
has_been_replaced = has_been_replaced | _has_been_replaced
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
def get_keys_to_not_convert(model):
r"""
An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
int8.
Parameters:
model (`torch.nn.Module`):
Input model
"""
# Create a copy of the model
# with init_empty_weights():
# tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
tied_model = model
tied_params = find_tied_parameters(tied_model)
# For compatibility with Accelerate < 0.18
if isinstance(tied_params, dict):
tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
else:
tied_keys = sum(tied_params, [])
has_tied_params = len(tied_keys) > 0
# Check if it is a base model
is_base_model = False
if hasattr(model, "base_model_prefix"):
is_base_model = not hasattr(model, model.base_model_prefix)
# Ignore this for base models (BertModel, GPT2Model, etc.)
if (not has_tied_params) and is_base_model:
return []
# otherwise they have an attached head
list_modules = list(model.named_children())
list_last_module = [list_modules[-1][0]]
# add last module together with tied weights
intersection = set(list_last_module) - set(tied_keys)
list_untouched = list(set(tied_keys)) + list(intersection)
# remove ".weight" from the keys
names_to_remove = [".weight", ".bias"]
filtered_module_names = []
for name in list_untouched:
for name_to_remove in names_to_remove:
if name_to_remove in name:
name = name.replace(name_to_remove, "")
filtered_module_names.append(name)
return filtered_module_names
def find_tied_parameters(model: nn.Module, **kwargs):
"""
Find the tied parameters in a given model.
<Tip warning={true}>
The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
them.
</Tip>
Args:
model (`torch.nn.Module`): The model to inspect.
Returns:
List[List[str]]: A list of lists of parameter names being all tied together.
Example:
```py
>>> from collections import OrderedDict
>>> import torch.nn as nn
>>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
>>> model.linear2.weight = model.linear1.weight
>>> find_tied_parameters(model)
[['linear1.weight', 'linear2.weight']]
```
"""
# Initialize result and named_parameters before recursing.
named_parameters = kwargs.get("named_parameters", None)
prefix = kwargs.get("prefix", "")
result = kwargs.get("result", {})
if named_parameters is None:
named_parameters = {n: p for n, p in model.named_parameters()}
else:
# A tied parameter will not be in the full `named_parameters` seen above but will be in the `named_parameters`
# of the submodule it belongs to. So while recursing we track the names that are not in the initial
# `named_parameters`.
for name, parameter in model.named_parameters():
full_name = name if prefix == "" else f"{prefix}.{name}"
if full_name not in named_parameters:
# When we find one, it has to be one of the existing parameters.
for new_name, new_param in named_parameters.items():
if new_param is parameter:
if new_name not in result:
result[new_name] = []
result[new_name].append(full_name)
# Once we have treated direct parameters, we move to the child modules.
for name, child in model.named_children():
child_name = name if prefix == "" else f"{prefix}.{name}"
find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result)
return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in result.items()])
class FindTiedParametersResult(list):
"""
This is a subclass of a list to handle backward compatibility for Transformers. Do not rely on the fact this is not
a list or on the `values` method as in the future this will be removed.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def values(self):
return sum([x[1:] for x in self], [])

View File

@ -0,0 +1,113 @@
# adapted from Hugging Face accelerate/utils/dataclasses.py
import warnings
from dataclasses import dataclass, field
from typing import List
import torch
@dataclass
class BnbQuantizationConfig:
"""
A plugin to enable BitsAndBytes 4bit and 8bit quantization
"""
load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."})
llm_int8_threshold: float = field(
default=6.0, metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"}
)
load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."})
bnb_4bit_quant_type: str = field(
default="fp4",
metadata={
"help": "set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','np4'}."
},
)
bnb_4bit_use_double_quant: bool = field(
default=False,
metadata={
"help": "enable nested quantization where the quantization constants from the first quantization are quantized again."
},
)
bnb_4bit_compute_dtype: bool = field(
default="fp16",
metadata={
"help": "This sets the computational type which might be different than the input time. For example, inputs might be "
"fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}."
},
)
torch_dtype: torch.dtype = field(
default=None,
metadata={
"help": "this sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value"
"to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model "
},
)
skip_modules: List[str] = field(
default=None,
metadata={
"help": "an explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`."
},
)
keep_in_fp32_modules: List[str] = field(
default=None,
metadata={"help": "an explicit list of the modules that we don't quantize. We keep them in `torch.float32`."},
)
def __post_init__(self):
if isinstance(self.bnb_4bit_compute_dtype, str):
if self.bnb_4bit_compute_dtype == "fp32":
self.bnb_4bit_compute_dtype = torch.float32
elif self.bnb_4bit_compute_dtype == "fp16":
self.bnb_4bit_compute_dtype = torch.float16
elif self.bnb_4bit_compute_dtype == "bf16":
self.bnb_4bit_compute_dtype = torch.bfloat16
else:
raise ValueError(
f"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}"
)
elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")
if self.skip_modules is not None and not isinstance(self.skip_modules, list):
raise ValueError("skip_modules must be a list of strings")
if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list):
raise ValueError("keep_in_fp_32_modules must be a list of strings")
if self.load_in_4bit:
self.target_dtype = "int4"
if self.load_in_8bit:
self.target_dtype = torch.int8
if self.load_in_4bit and self.llm_int8_threshold != 6.0:
warnings.warn("llm_int8_threshold can only be used for model loaded in 8bit")
if isinstance(self.torch_dtype, str):
if self.torch_dtype == "fp32":
self.torch_dtype = torch.float32
elif self.torch_dtype == "fp16":
self.torch_dtype = torch.float16
elif self.torch_dtype == "bf16":
self.torch_dtype = torch.bfloat16
else:
raise ValueError(f"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}")
if self.load_in_8bit and self.torch_dtype is None:
self.torch_dtype = torch.float16
if self.load_in_4bit and self.torch_dtype is None:
self.torch_dtype = self.bnb_4bit_compute_dtype
if not isinstance(self.torch_dtype, torch.dtype):
raise ValueError("torch_dtype must be a torch.dtype")

View File

@ -38,7 +38,7 @@ from transformers import BertForMaskedLM
import colossalai
# launch colossalai
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
# create model
config = BertConfig.from_pretrained('bert-base-uncased')
@ -114,30 +114,30 @@ We will follow this roadmap to develop Shardformer:
- [x] Unit Testing
- [ ] Policy Implementation
| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: |
| bert | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| t5 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| gpt2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| opt | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| bloom | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| chatglm2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| vit | [√] | [√] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| whisper | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
| sam | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| blip2 | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| falcon | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| mistral | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
|:-----------:|:---------------:|:-----------------:|:-------------------:|:-------:|:-----------:|:------------------:|:---------------:|:-----------------:|:-------:|
| bert | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| t5 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| gpt2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| opt | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| bloom | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| chatglm2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| vit | [√] | [√] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| whisper | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
| sam | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| blip2 | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| falcon | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| mistral | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
## 💡 API Design
@ -391,6 +391,43 @@ _POLICY_LIST = {
}
```
#### How to support those models in huggingface model hub but not in the transformers library
There are two cases:
1. the modeling file is in the `transformers` library but the model weight is not in the `transformers` library. E.g. model structure of "01-ai/Yi-34B" is the same as LLaMA but the weight is not in the `transformers` library. In this case, we should support llama as usual and Yi-34B is also supported by the llama policy. We do not need to add a new policy for Yi-34B.
2. the modeling file is not in the `transformers` library, such as the "THUDM/chatglm2-6b".
Take "THUDM/chatglm2-6b" as an example, we clearly illustrate how to support this model in the `shardformer`.
Unlike llama which is in `transformers` library, we cannot import chatglm2 model directly. Thus, the key in policy should be str of class name, rather than class itself.
E.g. for llama:
```python
policy[LlamaDecoderLayer] = ModulePolicyDescription(...)
```
for chatglm2:
```python
policy["GLMBlock"] = ModulePolicyDescription(...)
```
Then when registering such models in the autopolicy, we should follow below format:
```python
"transformers_modules.<modeling_filename>.<class_name>": PolicyLocation(
file_name="<policy_filename>", class_name="<policy_class_name>"
)
```
As for chatglm2 model, it should be:
```python
"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
)
```
When using such models, `AutoModel` is supported as usual. The policy will be automatically loaded by the autopolicy.
### Write Your Unit Testing
This section serves as the guideline for testing the `shardformer` module.
@ -424,13 +461,13 @@ We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate
We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.
In the case of using 2 GPUs, the training times are as follows.
| N_CTX | org_model | shard_model |
| :------: | :-----: | :-----: |
| 256 | 11.2ms | 17.2ms |
| 512 | 9.8ms | 19.5ms |
| 1024 | 19.6ms | 18.9ms |
| 2048 | 46.6ms | 30.8ms |
| 4096 | 160.5ms | 90.4ms |
| N_CTX | org_model | shard_model |
|:-----:|:---------:|:-----------:|
| 256 | 11.2ms | 17.2ms |
| 512 | 9.8ms | 19.5ms |
| 1024 | 19.6ms | 18.9ms |
| 2048 | 46.6ms | 30.8ms |
| 4096 | 160.5ms | 90.4ms |
<p align="center">
@ -440,13 +477,13 @@ In the case of using 2 GPUs, the training times are as follows.
In the case of using 4 GPUs, the training times are as follows.
| N_CTX | org_model | shard_model |
| :------: | :-----: | :-----: |
| 256 | 10.0ms | 21.1ms |
| 512 | 11.5ms | 20.2ms |
| 1024 | 22.1ms | 20.6ms |
| 2048 | 46.9ms | 24.8ms |
| 4096 | 160.4ms | 68.0ms |
| N_CTX | org_model | shard_model |
|:-----:|:---------:|:-----------:|
| 256 | 10.0ms | 21.1ms |
| 512 | 11.5ms | 20.2ms |
| 1024 | 22.1ms | 20.6ms |
| 2048 | 46.9ms | 24.8ms |
| 4096 | 160.4ms | 68.0ms |
@ -475,10 +512,10 @@ warmup_fraction = 0.03
| accuracy | f1 | loss | GPU number | model sharded |
| :------: | :-----: | :-----: | :--------: | :---------: |
| 0.82971 | 0.87713 | 0.23194 | 4 | True |
| 0.83797 | 0.88006 | 0.22683 | 2 | True |
| 0.84521 | 0.88700 | 0.21822 | 1 | False |
|:--------:|:-------:|:-------:|:----------:|:-------------:|
| 0.82971 | 0.87713 | 0.23194 | 4 | True |
| 0.83797 | 0.88006 | 0.22683 | 2 | True |
| 0.84521 | 0.88700 | 0.21822 | 1 | False |
Overall, the results demonstrate that using shardformers during model training does not affect the convergence.

View File

@ -28,7 +28,7 @@ def to_device(x: Any, device: torch.device) -> Any:
def train(args):
colossalai.launch_from_torch(config={}, seed=42)
colossalai.launch_from_torch(seed=42)
coordinator = DistCoordinator()
# prepare for data and dataset

View File

@ -1,6 +1,7 @@
"""
Shardformer Benchmark
"""
import torch
import torch.distributed as dist
import transformers
@ -84,5 +85,5 @@ def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, d
# start benchmark, command:
# torchrun --standalone --nproc_per_node=2 performance_benchmark.py
if __name__ == "__main__":
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
bench_shardformer.run(save_path=".", print_data=dist.get_rank() == 0)

View File

@ -1,8 +1,8 @@
from ._operation import all_to_all_comm
from .attn import AttnMaskType, ColoAttention
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
@ -25,6 +25,9 @@ __all__ = [
"FusedRMSNorm",
"FusedLinear1D_Col",
"ParallelModule",
"PaddingEmbedding",
"PaddingLMHead",
"VocabParallelLMHead1D",
"AttnMaskType",
"ColoAttention",
"all_to_all_comm",

View File

@ -8,7 +8,6 @@ from colossalai.kernel.kernel_loader import (
FlashAttentionForFloatAndCustomMaskLoader,
FlashAttentionLoader,
FlashAttentionWithCustomMaskLoader,
FlashAttentionWithPaddingMaskLoader,
KernelLoader,
)
@ -65,15 +64,17 @@ class ColoAttention:
half_dispatch_map = {
None: FlashAttentionLoader(),
AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(),
AttnMaskType.PADDED: FlashAttentionWithPaddingMaskLoader(),
AttnMaskType.PADDED: FlashAttentionLoader(),
AttnMaskType.CAUSAL: FlashAttentionLoader(),
AttnMaskType.PADDED_CAUSAL: FlashAttentionWithPaddingMaskLoader(),
AttnMaskType.PADDED_CAUSAL: FlashAttentionLoader(),
}
# fp32
float_dispatch_map = {
None: FlashAttentionForFloatAndCustomMaskLoader(),
AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(),
AttnMaskType.PADDED: FlashAttentionForFloatAndCustomMaskLoader(),
AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
AttnMaskType.PADDED_CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
}
ColoAttention._kernel_dispatch_map = {
torch.float16: half_dispatch_map,
@ -140,16 +141,22 @@ class ColoAttention:
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
else:
assert q_padding_mask.shape == (
b,
s_q,
), f"q_padding_mask shape {q_padding_mask.shape} should be the same. ({shape_4d})"
max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
if kv_padding_mask is None:
# self attention
kv_padding_mask = q_padding_mask
assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == (
max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices
else:
max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
assert kv_padding_mask.shape == (
b,
s_kv,
), f"q_padding_mask shape {q_padding_mask.shape} and kv_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device)
max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
attention_mask = q_padding_mask[:, None, :].expand(b, s_kv, s_q).to(dtype=dtype, device=device)
outputs.update(
{
"cu_seqlens_q": cu_seqlens_q,

View File

@ -21,10 +21,10 @@ from colossalai.tensor.d_tensor.api import (
)
from ._operation import gather_forward_split_backward, reduce_forward
from .parallel_module import ParallelModule
from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset
__all__ = ["Embedding1D", "VocabParallelEmbedding1D"]
__all__ = ["Embedding1D", "VocabParallelEmbedding1D", "PaddingEmbedding"]
class Embedding1D(ParallelModule):
@ -161,7 +161,80 @@ class Embedding1D(ParallelModule):
return output_parallel
class VocabParallelEmbedding1D(ParallelModule):
class PaddingEmbedding(PaddingParallelModule):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
device: torch.device = None,
weight: Optional[nn.Parameter] = None,
make_vocab_size_divisible_by: int = 64,
*args,
**kwargs,
):
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.embed_args = args
self.embed_kwargs = kwargs
self.padding_idx = padding_idx
if num_embeddings % make_vocab_size_divisible_by != 0:
self.num_embeddings = (
num_embeddings + make_vocab_size_divisible_by - (num_embeddings % make_vocab_size_divisible_by)
)
# create weight and bias
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
super().__init__(self.num_embeddings, num_embeddings, weight)
if weight is None:
self.reset_parameters()
def reset_parameters(self) -> None:
init.normal_(self.weight)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input: Tensor) -> Tensor:
return F.embedding(input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
@staticmethod
def from_native_module(
module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> PaddingParallelModule:
r"""
Convert a native pytorch embedding module to a parallel module.
"""
LazyInitContext.materialize(module)
# get the origin attributes
num_embeddings = module.num_embeddings
embedding_dim = module.embedding_dim
padding_idx = module.padding_idx
device = module.weight.device
# create the parallel module
padding_embedding = PaddingEmbedding(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
padding_idx=padding_idx,
device=device,
weight=module.weight,
*args,
**kwargs,
)
return padding_embedding
class VocabParallelEmbedding1D(PaddingParallelModule):
r"""Embedding parallelized in the vocabulary dimension.
Args:
@ -201,10 +274,10 @@ class VocabParallelEmbedding1D(ParallelModule):
process_group: ProcessGroup = None,
weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(),
make_vocab_size_divisible_by: int = 64,
*args,
**kwargs,
):
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.embed_args = args
@ -214,8 +287,23 @@ class VocabParallelEmbedding1D(ParallelModule):
tensor_parallel_size = dist.get_world_size(group=process_group)
tensor_parallel_rank = dist.get_rank(group=process_group)
self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
self.num_embeddings = self.num_embeddings_per_partition
# generate weight and bias
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
# calculate new padding size
multiple = make_vocab_size_divisible_by * tensor_parallel_size
if num_embeddings % multiple != 0:
self.num_embeddings = num_embeddings + multiple - (num_embeddings % multiple)
# resize vocabulary size
super().__init__(self.num_embeddings, num_embeddings, weight)
# deal with tensor parallelism
self.num_embeddings_per_partition = divide(self.num_embeddings, tensor_parallel_size)
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
@ -226,13 +314,6 @@ class VocabParallelEmbedding1D(ParallelModule):
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
# parameter
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if not is_distributed_tensor(self.weight):
sharded_weight = shard_rowwise(self.weight.data, process_group)
sharded_tensor_to_existing_param(sharded_weight, self.weight)
@ -243,7 +324,7 @@ class VocabParallelEmbedding1D(ParallelModule):
@staticmethod
def from_native_module(
module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
) -> PaddingParallelModule:
r"""
Convert a native pytorch embedding module to a parallel module.
"""
@ -303,11 +384,9 @@ class VocabParallelEmbedding1D(ParallelModule):
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
output_parallel = F.embedding(
masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs
)
# Mask the output embedding.
embedding_output = output_parallel.clone()
embedding_output[input_mask, :] = 0.0

View File

@ -32,7 +32,7 @@ from ._operation import (
reducescatter_forward_gather_backward,
split_forward_gather_backward,
)
from .parallel_module import ParallelModule
from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset
__all__ = ["Linear1D_Col", "Linear1D_Row"]
@ -84,8 +84,9 @@ class Linear1D_Col(ParallelModule):
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
**kwargs,
):
super().__init__()
super().__init__(weight=weight, bias_=bias_, **kwargs)
# Keep input parameters
self.in_features = in_features
@ -118,6 +119,7 @@ class Linear1D_Col(ParallelModule):
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if not is_distributed_tensor(self.weight):
sharded_weight = shard_rowwise(self.weight.data, self.process_group)
sharded_tensor_to_existing_param(sharded_weight, self.weight)
@ -140,7 +142,7 @@ class Linear1D_Col(ParallelModule):
@staticmethod
def from_native_module(
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
) -> ParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
@ -173,7 +175,6 @@ class Linear1D_Col(ParallelModule):
process_group=process_group,
weight=module.weight,
bias_=module.bias,
*args,
**kwargs,
)
@ -322,7 +323,7 @@ class Linear1D_Row(ParallelModule):
@staticmethod
def from_native_module(
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
) -> ParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
@ -356,7 +357,6 @@ class Linear1D_Row(ParallelModule):
process_group=process_group,
weight=module.weight,
bias_=module.bias,
*args,
**kwargs,
)
@ -439,3 +439,211 @@ class Linear1D_Row(ParallelModule):
return output
else:
return output, self.bias
class PaddingLMHead(PaddingParallelModule):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
make_vocab_size_divisible_by: int = 64,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
):
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
if out_features % make_vocab_size_divisible_by != 0:
self.out_features = (
out_features + make_vocab_size_divisible_by - (out_features % make_vocab_size_divisible_by)
)
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
if bias:
if bias_ is None:
self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
else:
bias_.data = bias_.data.to(device=device, dtype=dtype)
else:
bias_ = None
# resize embeddings
super().__init__(self.out_features, out_features, weight, bias_)
if weight is None:
self.reset_parameters(weight_initializer, bias_initializer)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
@staticmethod
def from_native_module(
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
) -> PaddingParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
# ensure only one process group is passed
lm_head_linear = PaddingLMHead(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
weight=module.weight,
bias_=module.bias,
**kwargs,
)
return lm_head_linear
def forward(self, input: Tensor) -> Tensor:
output = F.linear(input, self.weight, self.bias)
output = output[..., : self.old_num_embeddings]
return output
class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
r"""Linear layer with column parallelism.
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
its second dimension as :math:`A = [A_1, ..., A_p]`.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (`typing.Callable`):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
make_vocab_size_divisible_by: int = 64,
**kwargs,
):
# create weight and bias
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs))
if bias:
if bias_ is None:
bias_ = Parameter(torch.empty(out_features, **factory_kwargs))
else:
bias_ = None
# calculate new vocab size
self.tensor_parallel_size = dist.get_world_size(group=process_group)
new_out_features = out_features
multiple = make_vocab_size_divisible_by * self.tensor_parallel_size
if out_features % multiple != 0:
new_out_features = out_features + multiple - (out_features % multiple)
super().__init__(
in_features=in_features,
out_features=new_out_features,
bias=bias,
device=device,
process_group=process_group,
weight=weight,
bias_=bias_,
**kwargs,
new_num_embeddings=new_out_features,
old_num_embeddings=out_features,
)
# get the length of valid embeddings
tp_rank = dist.get_rank(process_group)
partition_size = self.new_num_embeddings // dist.get_world_size(process_group)
if self.old_num_embeddings >= (tp_rank + 1) * partition_size:
self.num_valid_embeddings_local = partition_size
elif self.old_num_embeddings >= tp_rank * partition_size:
self.num_valid_embeddings_local = self.old_num_embeddings - tp_rank * partition_size
else:
self.num_valid_embeddings_local = 0
@staticmethod
def from_native_module(
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
) -> PaddingParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
lm_head_linear = VocabParallelLMHead1D(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
weight=module.weight,
bias_=module.bias,
**kwargs,
)
return lm_head_linear
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
# get forward output
if self.skip_bias_add:
output, bias = super().forward(input_)
else:
output = super().forward(input_)
# delete the padding of output
if self.gather_output:
output = output[..., : self.old_num_embeddings]
else:
output = output[..., : self.num_valid_embeddings_local]
# return
if self.skip_bias_add:
return output, bias
return output

View File

@ -15,7 +15,14 @@ class DistCrossEntropy(Function):
"""
@staticmethod
def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup):
def forward(
ctx,
vocab_logits: torch.Tensor,
target: torch.Tensor,
ignore_index: int,
process_group: ProcessGroup,
vocab_size: int,
):
r"""
Calculate the cross entropy loss before gather, the origin loss function is as follows:
loss = -log(exp(x[class])/sum(exp(x[i]))
@ -41,15 +48,21 @@ class DistCrossEntropy(Function):
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
# mask the target in the local device
partition_vocab_size = vocab_logits.size()[-1]
rank = dist.get_rank(group=process_group)
world_size = dist.get_world_size(group=process_group)
global_vocab_size = partition_vocab_size * world_size
if vocab_size == None:
partition_vocab_size = vocab_logits.size()[-1]
global_vocab_size = partition_vocab_size * world_size
else:
global_vocab_size = vocab_size
partition_vocab_size = global_vocab_size // world_size
# [down, up) => false, other device and -100 => true
delta = (global_vocab_size + world_size - 1) // world_size
down_threshold = rank * delta
up_threshold = down_threshold + delta
if up_threshold > global_vocab_size:
up_threshold = global_vocab_size
mask = (target < down_threshold) | (target >= up_threshold)
masked_target = target.clone() - down_threshold
masked_target[mask] = 0
@ -57,7 +70,8 @@ class DistCrossEntropy(Function):
# reshape the logits and target
# reshape the vocab_logits to [bath_size * seq_len, vocab_size]
# reshape the labels to [bath_size * seq_len]
logits_2d = vocab_logits.view(-1, partition_vocab_size)
self_vocab_size = vocab_logits.size()[-1]
logits_2d = vocab_logits.view(-1, self_vocab_size)
masked_target_1d = masked_target.view(-1)
# extract the x[class] and set the x[other device] to zero
@ -104,10 +118,14 @@ class DistCrossEntropy(Function):
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
grad_logits.mul_(grad_output.unsqueeze(dim=-1))
return grad_logits, None, None, None
return grad_logits, None, None, None, None
def cross_entropy_1d(
vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None
vocab_logits: torch.Tensor,
labels: torch.Tensor,
ignore_index: int = -100,
process_group: ProcessGroup = None,
vocab_size: int = None,
) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group)
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size)

View File

@ -225,7 +225,13 @@ class FusedLayerNorm(BaseLayerNorm):
# fall back to the normal fused layernorm is not built
ApexFusedLayerNorm = FusedLayerNormWithHook
else:
ApexFusedLayerNorm = FusedLayerNormWithHook
try:
ApexFusedLayerNorm = FusedLayerNormWithHook
except NameError:
warnings.warn(
"Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
)
return module
layernorm = (
ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
@ -275,19 +281,16 @@ class FusedRMSNorm(BaseLayerNorm):
)
LazyInitContext.materialize(module)
# to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
normalized_shape = module.weight.shape[0]
eps = module.variance_epsilon
elementwise_affine = True
else:
# get the attributes of the module
normalized_shape = module.normalized_shape
eps = module.eps
elementwise_affine = module.elementwise_affine
# try to get normalized_shape, eps, elementwise_affine from the module
normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
elementwise_affine = getattr(module, "elementwise_affine", True)
rmsnorm = FusedRMSNormWithHook(
normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
normalized_shape=normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
)
rmsnorm.weight = module.weight

View File

@ -3,7 +3,7 @@
import itertools
from abc import ABC, abstractmethod
from typing import List, Union
from typing import List, Optional, Union
import torch
import torch.nn as nn
@ -20,11 +20,15 @@ from colossalai.tensor.d_tensor import (
is_distributed_tensor,
sharded_tensor_to_param,
)
from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
__all__ = ["ParallelModule"]
class ParallelModule(nn.Module, ABC):
def __init__(self, **kwargs):
super().__init__()
@abstractmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
@ -54,7 +58,7 @@ class ParallelModule(nn.Module, ABC):
"""
for name, param in self._parameters.items():
if param is not None:
destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars)
destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars).data
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
@ -171,3 +175,187 @@ class ParallelModule(nn.Module, ABC):
input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child
if input_name not in self._modules and input_name not in local_state:
unexpected_keys.append(key)
class PaddingParallelModule(ParallelModule):
def __init__(
self,
new_num_embeddings: int,
old_num_embeddings: int,
weight: Optional[nn.Parameter],
bias_: Optional[nn.Parameter] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.new_num_embeddings = new_num_embeddings
self.old_num_embeddings = old_num_embeddings
self.weight = weight
self.bias = bias_
if not (is_distributed_tensor(self.weight) or self.weight.shape[0] == self.new_num_embeddings):
self.resize_embedding_weight()
if self.bias is not None and not (
is_distributed_tensor(self.bias) or self.bias.shape[0] == self.new_num_embeddings
):
self.resize_embedding_bias()
@abstractmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
) -> "PaddingParallelModule":
"""
Convert a native PyTorch module to a parallelized module.
Args:
module (nn.Module): the module to be converted.
process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication.
If this is a list, the process group at the ith index of the list will correspond to the process group
in the ith axis of the device mesh. Defaults to None, which means the global process group.
"""
raise NotImplementedError
def _save_to_state_dict(self, destination, prefix, keep_vars):
r"""Saves module state to `destination` dictionary, containing a state
of the module, but not its descendants. This is called on every
submodule in :meth:`~torch.nn.Module.state_dict`.
In rare cases, subclasses can achieve class-specific behavior by
overriding this method with custom logic.
Args:
destination (dict): a dict where state will be stored
prefix (str): the prefix for parameters and buffers used in this
module
"""
for name, param in self._parameters.items():
if param is not None:
param = gather_distributed_param(param, keep_vars=keep_vars)
if is_padded_tensor(param):
param = to_unpadded_tensor(param)
destination[prefix + name] = param.data
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf.detach()
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
destination[extra_state_key] = self.get_extra_state()
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
r"""Copies parameters and buffers from :attr:`state_dict` into only
this module, but not its descendants. This is called on every submodule
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
For state dicts without metadata, :attr:`local_metadata` is empty.
Subclasses can achieve class-specific backward compatible loading using
the version number at `local_metadata.get("version", None)`.
.. note::
:attr:`state_dict` is not the same object as the input
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
it can be modified.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
prefix (str): the prefix for parameters and buffers used in this
module
local_metadata (dict): a dict containing the metadata for this module.
See
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` with :attr:`prefix` match the names of
parameters and buffers in this module
missing_keys (list of str): if ``strict=True``, add missing keys to
this list
unexpected_keys (list of str): if ``strict=True``, add unexpected
keys to this list
error_msgs (list of str): error messages should be added to this
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
"""
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
for name, param in local_state.items():
key = prefix + name
if key in state_dict:
input_param = state_dict[key]
if not torch.overrides.is_tensor_like(input_param):
error_msgs.append(
'While copying the parameter named "{}", '
"expected torch.Tensor or Tensor-like object from checkpoint but "
"received {}".format(key, type(input_param))
)
continue
if is_padded_tensor(param):
input_param = to_padded_tensor(input_param, param._current_length, param._padding_dim)
if is_distributed_tensor(param):
# shard the input param
device_mesh = get_device_mesh(param)
sharding_spec = get_sharding_spec(param)
sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
input_param = sharded_tensor_to_param(sharded_tensor)
elif is_customized_distributed_tensor(param):
input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn)
# This is used to avoid copying uninitialized parameters into
# non-lazy modules, since they dont have the hook to do the checks
# in such case, it will error when accessing the .shape attribute.
is_param_lazy = torch.nn.parameter.is_lazy(param)
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0]
if not is_param_lazy and input_param.shape != param.shape:
# local shape should match the one in checkpoint
error_msgs.append(
"size mismatch for {}: copying a param with shape {} from checkpoint, "
"the shape in current model is {}.".format(key, input_param.shape, param.shape)
)
continue
try:
with torch.no_grad():
param.copy_(input_param)
except Exception as ex:
error_msgs.append(
'While copying the parameter named "{}", '
"whose dimensions in the model are {} and "
"whose dimensions in the checkpoint are {}, "
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
)
elif strict:
missing_keys.append(key)
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
if extra_state_key in state_dict:
self.set_extra_state(state_dict[extra_state_key])
elif strict:
missing_keys.append(extra_state_key)
elif strict and (extra_state_key in state_dict):
unexpected_keys.append(extra_state_key)
if strict:
for key in state_dict.keys():
if key.startswith(prefix) and key != extra_state_key:
input_name = key[len(prefix) :]
input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child
if input_name not in self._modules and input_name not in local_state:
unexpected_keys.append(key)
def resize_embedding_weight(self):
self.weight = to_padded_tensor(self.weight, self.new_num_embeddings, 0)
def resize_embedding_bias(self):
self.bias = to_padded_tensor(self.bias, self.new_num_embeddings, 0)

View File

@ -1287,3 +1287,16 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
)
return forward
def get_jit_fused_bert_intermediate_forward():
from transformers.models.bert.modeling_bert import BertIntermediate
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
def forward(self: BertIntermediate, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, bias = self.dense(hidden_states)
hidden_states = JitGeLUFunction.apply(hidden_states, bias)
return hidden_states
return forward

View File

@ -129,3 +129,17 @@ def get_jit_fused_blip2_QFormer_output_forward():
return hidden_states
return forward
def get_jit_fused_blip2_mlp_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2MLP
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
def forward(self: Blip2MLP, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, bias = self.fc1(hidden_states)
hidden_states = JitGeLUFunction.apply(hidden_states, bias)
hidden_states = self.fc2(hidden_states)
return hidden_states
return forward

View File

@ -6,6 +6,7 @@ import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import functional as F
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@ -205,12 +206,13 @@ class BloomPipelineForwards:
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
# causal_mask is constructed every stage and its input is passed through different stages
causal_mask = self._prepare_attn_mask(
causal_mask = _prepare_4d_causal_attention_mask(
attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=hidden_states,
past_key_values_length=past_key_values_length,
)
causal_mask = causal_mask.bool()
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config and shard_config.enable_sequence_parallelism:
@ -227,21 +229,15 @@ class BloomPipelineForwards:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
alibi,
causal_mask,
layer_past,
head_mask[i],
use_cache,
output_attentions,
)
else:
outputs = block(
@ -1002,11 +998,13 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
causal_mask = self._prepare_attn_mask(
causal_mask = _prepare_4d_causal_attention_mask(
attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=hidden_states,
past_key_values_length=past_key_values_length,
)
causal_mask = causal_mask.bool()
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward(
@ -1018,21 +1016,15 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
alibi,
causal_mask,
layer_past,
head_mask[i],
use_cache,
output_attentions,
)
else:
outputs = block(

View File

@ -12,7 +12,6 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
def get_flash_core_attention_forward():
@ -31,7 +30,12 @@ def get_flash_core_attention_forward():
device=query_layer.device,
)
temp_mask = (
torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device)
torch.ones(
query_layer.shape[2],
key_layer.shape[2],
dtype=torch.bool,
device=query_layer.device,
)
.tril(diagonal=0)
.expand(query_layer.shape[0], 1, -1, -1)
)
@ -49,6 +53,7 @@ def get_flash_core_attention_forward():
attention_mask=attn_bias,
attention_mask_type=attention_mask_type,
dropout_p=dropout_p,
scale=1.0 / self.norm_factor,
)
context_layer = context_layer.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
@ -115,7 +120,7 @@ class ChatGLMPipelineForwards:
@staticmethod
def chatglm_model_forward(
self: ChatGLMModel,
self: "ChatGLMModel",
input_ids,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
@ -194,7 +199,9 @@ class ChatGLMPipelineForwards:
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = split_forward_gather_backward(
hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
)
for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx)
@ -224,7 +231,9 @@ class ChatGLMPipelineForwards:
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = gather_forward_split_backward(
hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@ -254,7 +263,7 @@ class ChatGLMPipelineForwards:
@staticmethod
def chatglm_for_conditional_generation_forward(
self: ChatGLMForConditionalGeneration,
self: "ChatGLMForConditionalGeneration",
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,

View File

@ -1,9 +1,16 @@
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@ -99,11 +106,17 @@ def get_tp_falcon_decoder_layer_forward():
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
**kwargs,
):
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
residual = hidden_states
if self.config.new_decoder_architecture:
@ -117,10 +130,12 @@ def get_tp_falcon_decoder_layer_forward():
attention_layernorm_out,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
**kwargs,
)
attention_output = attn_outputs[0]
@ -154,87 +169,6 @@ def get_tp_falcon_decoder_layer_forward():
return forward
def get_falcon_flash_attention_forward():
try:
from xformers.ops import memory_efficient_attention as me_attention
except:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from transformers.models.falcon.modeling_falcon import FalconAttention
def forward(
self: FalconAttention,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, query_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
key_layer = key_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads,
query_length,
self.head_dim,
)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, kv_length, head_dim]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
_, kv_length, _ = key_layer.shape
if use_cache:
present = (key_layer, value_layer)
else:
present = None
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).contiguous()
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous()
value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous()
if alibi is not None:
attention_mask_float = (
attention_mask_float + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
)
batch_size, src_len = query_layer_.size()[0], query_layer_.size()[1]
tgt_len = key_layer_.size()[1]
attention_mask_float = attention_mask_float.expand(batch_size, self.num_heads, src_len, tgt_len).contiguous()
context_layer = me_attention(
query_layer_,
key_layer_,
value_layer_,
attn_bias=attention_mask_float,
scale=self.inv_norm_factor,
p=self.attention_dropout.p,
)
batch_size, seq_length, _, _ = context_layer.shape
context_layer = context_layer.reshape(batch_size, seq_length, -1)
output_tensor = self.dense(context_layer)
return output_tensor, present
return forward
class FalconPipelineForwards:
"""
This class serves as a micro library for falcon pipeline forwards.
@ -246,6 +180,7 @@ class FalconPipelineForwards:
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@ -274,17 +209,6 @@ class FalconPipelineForwards:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
else:
past_key_values = self._convert_to_rw_cache(past_key_values)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
# case: First stage of training
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
@ -295,16 +219,22 @@ class FalconPipelineForwards:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
@ -312,22 +242,80 @@ class FalconPipelineForwards:
# Compute alibi tensor: check build_alibi_tensor documentation
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
past_key_values_length = past_key_values[0][0].shape[-2]
if self.use_alibi:
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
mask = (
torch.ones(
(batch_size, seq_length + past_key_values_length), device=inputs_embeds.device, dtype=torch.long
)
if attention_mask is None
else attention_mask
)
alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
else:
alibi = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0)
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
if alibi is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
elif head_mask is None:
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
attention_mask_2d = attention_mask
# We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
# We take care to integrate alibi bias in the attention_mask here.
if attention_mask_2d is None:
attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
else:
attention_mask = torch.masked_fill(
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
attention_mask < -1,
torch.finfo(alibi.dtype).min,
)
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
if seq_length > 1:
attention_mask = AttentionMaskConverter._unmask_unattended(
attention_mask, attention_mask_2d, unmasked_value=0.0
)
else:
# PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
start_idx, end_idx = stage_index[0], stage_index[1]
for i, (block, layer_past) in enumerate(
@ -337,31 +325,23 @@ class FalconPipelineForwards:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
alibi,
causal_mask,
attention_mask,
position_ids,
head_mask[i],
layer_past,
use_cache,
output_attentions,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
@ -382,9 +362,6 @@ class FalconPipelineForwards:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if presents is not None:
presents = self._convert_cache_to_standard_format(presents, batch_size)
if stage_manager.is_last_stage():
if not return_dict:
return tuple(

View File

@ -26,7 +26,6 @@ from colossalai.shardformer.layer._operation import gather_forward_split_backwar
from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
from ..layer._operation import gather_forward_split_backward
logger = logging.get_logger(__name__)
@ -178,11 +177,9 @@ class GPT2PipelineForwards:
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if stage_manager.is_first_stage():
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
else:
if position_ids is None:
position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
position_ids = position_ids.unsqueeze(0)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
@ -240,22 +237,16 @@ class GPT2PipelineForwards:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
use_cache,
output_attentions,
)
else:
outputs = block(
@ -397,13 +388,11 @@ class GPT2PipelineForwards:
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
)
else:
loss = loss_fct(shift_logits, shift_labels)
if not shard_config.parallel_output:
lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
@ -1301,12 +1290,12 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
)
if not shard_config.parallel_output:
lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
@ -1321,3 +1310,18 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
)
return forward
def get_jit_fused_gpt2_mlp_forward():
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
def forward(self: GPT2MLP, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states, bias = self.c_fc(hidden_states)
hidden_states = JitGeLUFunction.apply(hidden_states, bias)
hidden_states = self.c_proj(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
return forward

View File

@ -148,11 +148,9 @@ class GPTJPipelineForwards:
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
# position id to be assigned not just for the first stage for attn input
if position_ids is not None:
position_ids = position_ids.view(-1, seq_length)
else:
if position_ids is None:
position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
position_ids = position_ids.unsqueeze(0)
if stage_manager.is_first_stage():
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
@ -201,21 +199,15 @@ class GPTJPipelineForwards:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
attention_mask,
position_ids,
head_mask[i],
use_cache,
output_attentions,
)
else:
outputs = block(
@ -627,7 +619,9 @@ def get_gptj_flash_attention_forward():
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
present = (key, value)
# Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
# Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
present = (key.to(hidden_states.dtype), value)
else:
present = None

View File

@ -7,6 +7,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -16,6 +17,8 @@ from transformers.models.llama.modeling_llama import (
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaModel,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
apply_rotary_pos_emb,
repeat_kv,
)
@ -31,13 +34,6 @@ from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, cross_entropy_1d
try:
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
LATEST_VERSION = True
except ImportError:
LATEST_VERSION = False
class LlamaPipelineForwards:
"""
@ -75,13 +71,13 @@ class LlamaPipelineForwards:
# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
batch_size, seq_length, _ = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
@ -111,11 +107,12 @@ class LlamaPipelineForwards:
if position_ids is None:
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
position_ids = position_ids.unsqueeze(0)
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
@ -123,20 +120,32 @@ class LlamaPipelineForwards:
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
)
else:
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
)
if LATEST_VERSION:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
)
if self.gradient_checkpointing and self.training:
@ -149,7 +158,7 @@ class LlamaPipelineForwards:
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
next_decoder_cache = None
start_idx, end_idx = stage_index[0], stage_index[1]
num_ckpt_layers = 0
@ -159,8 +168,10 @@ class LlamaPipelineForwards:
if shard_config.gradient_checkpoint_config is not None:
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
stage=stage_manager.stage,
num_stages=stage_manager.num_stages,
num_layers=end_idx - start_idx,
model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0,
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
num_model_chunks=stage_manager.num_model_chunks,
)
assert num_ckpt_layers <= end_idx - start_idx
@ -168,30 +179,22 @@ class LlamaPipelineForwards:
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if idx - start_idx < num_ckpt_layers:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
None,
past_key_values,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
@ -199,7 +202,7 @@ class LlamaPipelineForwards:
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
@ -212,7 +215,16 @@ class LlamaPipelineForwards:
next_cache = next_decoder_cache if use_cache else None
if stage_manager.is_last_stage():
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
@ -316,7 +328,10 @@ class LlamaPipelineForwards:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
@ -455,23 +470,25 @@ class LlamaPipelineForwards:
def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
llama_version = 2
try:
from transformers.models.llama.modeling_llama import repeat_kv
except:
warnings.warn("using llamav1, llamav1 hasn't repeat_kv function")
llama_version = 1
def forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[dict] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
if sp_mode in ["split_gather", "ring"]:
@ -495,21 +512,23 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
if llama_version == 2:
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
@ -570,7 +589,10 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
@ -584,7 +606,11 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
)
if self.gradient_checkpointing and self.training:
@ -735,11 +761,13 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
)
if not return_dict:
@ -913,7 +941,10 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
@ -929,10 +960,12 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length
)

View File

@ -1,70 +1,608 @@
from typing import Optional, Tuple
import warnings
from typing import List, Optional, Tuple, Union
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.models.mistral.modeling_mistral import MistralForCausalLM, MistralModel
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention
logger = logging.get_logger(__name__)
def get_mistral_flash_attention_forward():
class MistralForwards:
@staticmethod
def mistral_model_forward(
self: MistralModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
if use_cache:
logger.warning_once("use_cache=True is not supported for Mistral models at the moment.")
use_cache = False
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
past_key_values_length = 0
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length, seq_length)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
)
else:
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
start_idx, end_idx = stage_index[0], stage_index[1]
num_ckpt_layers = 0
if self.gradient_checkpointing and self.training:
num_ckpt_layers = end_idx - start_idx
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
if shard_config.gradient_checkpoint_config is not None:
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
stage=stage_manager.stage,
num_stages=stage_manager.num_stages,
num_layers=end_idx - start_idx,
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
num_model_chunks=stage_manager.num_model_chunks,
)
assert num_ckpt_layers <= end_idx - start_idx
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if idx - start_idx < num_ckpt_layers:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if stage_manager.is_last_stage():
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
else:
return {"hidden_states": hidden_states}
@staticmethod
def mistral_for_causal_lm_forward(
self: MistralForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, MistralForCausalLM
>>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = MistralForwards.mistral_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
if stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
else:
hidden_states = outputs.get("hidden_states")
return {"hidden_states": hidden_states}
@staticmethod
def mistral_for_sequence_classification_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = MistralForwards.mistral_model_forward(
self.model,
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
if input_ids is not None:
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
batch_size = inputs_embeds.shape[0]
else:
batch_size = hidden_states.shape[0]
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
logits.device
)
else:
sequence_lengths = -1
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
else:
hidden_states = transformer_outputs.get("hidden_states")
return {"hidden_states": hidden_states}
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig):
logger = logging.get_logger(__name__)
assert shard_config.enable_flash_attention, "Flash Attention is not enabled."
def forward(
self: MistralModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length, seq_length)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
inputs_embeds.dtype,
inputs_embeds.device,
q_padding_mask=attention_mask,
is_causal=True,
)
else:
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return forward
def get_mistral_flash_attention_forward(shard_config: ShardConfig):
from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv
from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
def forward(
self: MistralAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = (
self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape)
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
flash_attention_mask = None
attn_mask_type = AttnMaskType.causal
if attention_mask != None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
attn_mask_type = AttnMaskType.paddedcausal
attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
attn_output = attention(
query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)

View File

@ -3,6 +3,7 @@ from typing import List, Optional, Tuple, Union
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -42,7 +43,7 @@ def _get_attention_mask(
is_causal=True,
)
else:
attention_mask = self.decoder._prepare_decoder_attention_mask(
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
@ -112,7 +113,7 @@ class OPTPipelineForwards:
inputs_embeds = decoder.project_in(inputs_embeds)
device = input_ids.device if input_ids is not None else inputs_embeds.device
inputs_embeds.dtype
hidden_states = inputs_embeds
else:
if hidden_states is None:
raise ValueError("hidden_states shouldn't be None for intermediate stages.")
@ -125,12 +126,25 @@ class OPTPipelineForwards:
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values_length + seq_length
# embed positions
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
elif attention_mask.shape[1] != mask_seq_length:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
if self.decoder._use_flash_attention_2:
# 2d mask is passed through the layers
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
attention_mask = (
torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
if attention_mask is None
else attention_mask
)
else:
# 4d mask is passed through the layers
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
elif attention_mask.shape[1] != mask_seq_length:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
)
causal_attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, hidden_states, past_key_values_length
)
if stage_manager.is_first_stage():
@ -205,20 +219,14 @@ class OPTPipelineForwards:
past_key_value = past_key_values[idx] if past_key_values is not None else None
if decoder.gradient_checkpointing and decoder.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_attention_mask,
head_mask[idx] if head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(

Some files were not shown because too many files have changed in this diff Show More