mirror of https://github.com/hpcaitech/ColossalAI
[Sync] Update from main to feature/colossal-infer (Merge pull request #5685 from yuanheng-zhao/inference/merge/main)

commit db7b3051f4
@@ -56,7 +56,7 @@ jobs:
   needs: detect-changed-doc
   runs-on: [self-hosted, gpu]
   container:
-    image: hpcaitech/pytorch-cuda:2.0.0-11.7.0
+    image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
     options: --gpus all --rm
   timeout-minutes: 20
   defaults:

@@ -24,7 +24,7 @@ jobs:
   version=$(cat version.txt)
   tag=hpcaitech/colossalai:$version
   latest=hpcaitech/colossalai:latest
-  docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 --build-arg VERSION=v${version} -t $tag ./docker
+  docker build --build-arg VERSION=v${version} -t $tag ./docker
   docker tag $tag $latest
   echo "tag=${tag}" >> $GITHUB_OUTPUT
   echo "latest=${latest}" >> $GITHUB_OUTPUT
LICENSE (15 changed lines)

@@ -552,3 +552,18 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
+---------------- LICENSE FOR Hugging Face accelerate ----------------
+
+Copyright 2021 The HuggingFace Team
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
README.md (18 changed lines)

@@ -25,6 +25,8 @@
 </div>

 ## Latest News
+* [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
+* [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series)
 * [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here)
 * [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
 * [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)

@@ -52,7 +54,7 @@
 <li>
 <a href="#Parallel-Training-Demo">Parallel Training Demo</a>
 <ul>
-<li><a href="#LLaMA2">LLaMA 1/2</a></li>
+<li><a href="#LLaMA3">LLaMA 1/2/3 </a></li>
 <li><a href="#MoE">MoE</a></li>
 <li><a href="#GPT-3">GPT-3</a></li>
 <li><a href="#GPT-2">GPT-2</a></li>

@@ -131,7 +133,7 @@ distributed training and inference in a few lines.

 [Open-Sora](https://github.com/hpcaitech/Open-Sora):Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models
 [[code]](https://github.com/hpcaitech/Open-Sora)
-[[blog]](https://hpc-ai.com/blog/open-sora-v1.0)
+[[blog]](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
 [[HuggingFace model weights]](https://huggingface.co/hpcai-tech/Open-Sora)
 [[Demo]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo)

@@ -270,13 +272,21 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
 <p align="right">(<a href="#top">back to top</a>)</p>

 ## Parallel Training Demo
+### LLaMA3
+<p align="center">
+<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/images/LLaMA3-70B-H100.png" width=600/>
+</p>
+
+- 70 billion parameter LLaMA3 model training accelerated by 18%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
+
 ### LLaMA2
 <p align="center">
 <img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/llama2_pretraining.png" width=600/>
 </p>

 - 70 billion parameter LLaMA2 model training accelerated by 195%
-[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2)
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
 [[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)

 ### LLaMA1

@@ -285,7 +295,7 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
 </p>

 - 65-billion-parameter large model pretraining accelerated by 38%
-[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
 [[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)

 ### MoE
@@ -1 +0,0 @@
-0.0.1
@@ -1,6 +1,6 @@
 <div align="center">
 <h1>
 <img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/colossalllam2.jpg?raw=true" width=800/>
 Colossal-LLaMA
 </h1>
 </div>

@@ -47,6 +47,7 @@
 - [Citations](#citations)

 ## News
+* [2024/4] Support continual pre-training and supervised fine-tuning of LLaMA-3.
 * [2024/01] [Construct Refined 13B Private Model With Just $5000 USD, Upgraded Colossal-AI Llama-2 Open Source](https://hpc-ai.com/blog/colossal-llama-2-13b).
 [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2)
 [[blog]](https://hpc-ai.com/blog/colossal-llama-2-13b)
@@ -289,7 +290,7 @@ Here is details about CLI arguments:

 #### 1. Install required packages
 ```
-cd Colossal-LLaMA-2
+cd Colossal-LLaMA
 pip install -r requirements.txt
 ```
 #### 2. Install `xentropy`, `layer_norm` and `rotary`

@@ -314,7 +315,7 @@ Initialize new tokenizer with additional Chinese tokens. Additional Chinese toke
 Command to initialize new tokenizer:
 ```bash
 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION='python'
-python colossal_llama2/tokenizer/init_tokenizer.py \
+python colossal_llama/tokenizer/init_tokenizer.py \
     --source_tokenizer_dir "<SOURCE_TOKENIZER_DIR>" \
     --target_tokenizer_dir "<TARGET_TOKENIZER_DIR>" \
     --expand_tokens_file "<NEW_TOKENS_FILE>.jsonl"

@@ -328,7 +329,7 @@ Here is details about CLI arguments:
 Initialize the new model checkpoint by calculating the mean values from the original model checkpoint.
 Command to initialize new model checkpoint:
 ```bash
-python colossal_llama2/model/init_model.py \
+python colossal_llama/model/init_model.py \
     --source_model_and_tokenizer_path "<SOURCE_MODEL_AND_TOKENIZER_DIR>" \
     --target_tokenizer_path "<TARGET_TOKENIZER_DIR>" \
     --target_model_path "<TARGET_MODEL_DIR>"

@@ -362,18 +363,17 @@ Command to convert jsonl dataset to arrow format:
 python prepare_pretrain_dataset.py \
     --data_input_dirs "<JSONL_DIR_1>,<JSONL_DIR_2>,<JSONL_DIR_3>" \
     --tokenizer_dir "<TOKENIZER_DIR>" \
-    --data_cache_dir "jsonl_to_arrow_cache" \
-    --data_jsonl_output_dir "spliced_tokenized_output_jsonl" \
-    --data_arrow_output_dir "spliced_tokenized_output_arrow" \
+    --data_output_dirs "spliced tokenized output" \
     --max_length 4096 \
     --num_spliced_dataset_bins 10
 ```
 Here is details about CLI arguments:
 * Source data directory: `data_input_dirs`. Each `<JSONL_DIR>` can have multiple file in `jsonl` format.
 * Tokenizer directory: `tokenizer_dir`. Path to the tokenizer in Hugging Face format.
-* Data cache directory: `data_cache_dir`. Directory to store Hugging Face data cache. Default case will create `cache` folder locally.
-* Output directory for jsonl format: `data_jsonl_output_dir`. Output directory to store converted dataset in jsonl format.
-* Output directory for arrow format: `data_arrow_output_dir`. Output directory to store converted dataset in arrow format, which can be used for training directly.
+* Data output directory: `data_output_dirs`. Directory to store preprocessed output, including three sub-directories:
+    * `cache`: Directory to store Hugging Face data cache.
+    * `jsonl`: Output directory to store converted dataset in jsonl format.
+    * `arrow`: Output directory to store converted dataset in arrow format, which can be used for training directly.
 * Max length: `max_length`. Max length of spliced samples. Default value is 4096.
 * Number of bins for each category: `num_spliced_dataset_bins`. Number of bins for each category, used for bucket-based training.
@@ -392,13 +392,15 @@ Command to convert jsonl dataset to arrow format is similar to the command in [3
 python prepare_sft_dataset.py.py \
     --data_input_dirs "<JSONL_DIR_1>,<JSONL_DIR_2>,<JSONL_DIR_3>" \
     --tokenizer_dir "<TOKENIZER_DIR>" \
-    --data_cache_dir "jsonl_to_arrow_cache" \
-    --data_jsonl_output_dir "spliced_tokenized_output_jsonl" \
-    --data_arrow_output_dir "spliced_tokenized_output_arrow" \
+    --data_output_dirs "spliced tokenized output" \
     --max_length 4096 \
-    --num_spliced_dataset_bins 10
+    --num_spliced_dataset_bins 10 \
+    --llama_version 3
 ```

+Additional CLI arguments:
+* LLaMA verison: `llama_version`. Specify the LLaMA version.
+
 #### 4. Command Line Arguments for Training

 ##### 4.1 Arguments for Pretraining
@@ -83,7 +83,7 @@ class Conversation:
 }


-conv = Conversation(
+LLaMA2_Conv = Conversation(
     system="A chat between a curious human and an artificial intelligence assistant. "
     "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
     roles=("Human", "Assistant"),

@@ -93,4 +93,14 @@ conv = Conversation(
     seps=["<s>", "</s>"],
 )

-default_conversation = conv
+LLaMA3_Conv = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+    "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+    roles=("Human", "Assistant"),
+    messages=[],
+    offset=0,
+    sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
+    seps=["<|begin_of_text|>", "<|end_of_text|>"],
+)
+
+default_conversation = LLaMA3_Conv
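The new LLaMA-3 template above only swaps the separator tokens and keeps the same system prompt and roles. As a purely illustrative sketch (the `render_turn` helper below is hypothetical and not part of the repository), one Human/Assistant round framed by those separators could look like:

```python
# Illustrative only: frames one exchange with the seps/roles of LLaMA3_Conv above.
# render_turn is a hypothetical helper; the repository applies the template inside its
# dataset tokenization code, not like this.
BEGIN, END = "<|begin_of_text|>", "<|end_of_text|>"
SYSTEM = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
)


def render_turn(question: str, answer: str) -> str:
    return f"{BEGIN}{SYSTEM}Human: {question}\nAssistant: {answer}{END}"


print(render_turn("What is Colossal-AI?", "A system for distributed training."))
```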
@@ -80,15 +80,19 @@ class DataCollatorForSupervisedDataset(object):

         # `List[torch.Tensor]`
         batch_input_ids = [
-            torch.LongTensor(instance["input_ids"][: self.max_length])
-            if len(instance["input_ids"]) > self.max_length
-            else torch.LongTensor(instance["input_ids"])
+            (
+                torch.LongTensor(instance["input_ids"][: self.max_length])
+                if len(instance["input_ids"]) > self.max_length
+                else torch.LongTensor(instance["input_ids"])
+            )
             for instance in instances
         ]
         batch_labels = [
-            torch.LongTensor(instance["labels"][: self.max_length])
-            if len(instance["labels"]) > self.max_length
-            else torch.LongTensor(instance["labels"])
+            (
+                torch.LongTensor(instance["labels"][: self.max_length])
+                if len(instance["labels"]) > self.max_length
+                else torch.LongTensor(instance["labels"])
+            )
             for instance in instances
         ]
@@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple, Union

 from datasets import dataset_dict
 from torch.utils.data import ConcatDataset, Dataset, IterableDataset
+from transformers import AutoTokenizer
 from transformers.models.llama.tokenization_llama import LlamaTokenizer
 from transformers.tokenization_utils import PreTrainedTokenizer


@@ -71,7 +72,7 @@ def supervised_tokenize_pretrain(

 def supervised_tokenize_sft(
     data_point: Dict[str, str],
-    tokenizer: LlamaTokenizer,
+    tokenizer: AutoTokenizer,
     conversation_template: Conversation = default_conversation,
     ignore_index: int = None,
     max_length: int = 4096,
@@ -1,7 +1,7 @@
 import argparse

 import torch
-from colossal_llama2.dataset.conversation import default_conversation
+from colossal_llama.dataset.conversation import default_conversation
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.logging import get_dist_logger
@@ -11,12 +11,12 @@ import os
 import time
 from multiprocessing import cpu_count

-from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
+from colossal_llama.dataset.spliced_and_tokenized_dataset import (
     ClosedToConstantLengthSplicedDataset,
     supervised_tokenize_pretrain,
 )
 from datasets import dataset_dict, load_dataset
-from transformers.models.llama.tokenization_llama import LlamaTokenizer
+from transformers import AutoTokenizer

 from colossalai.logging import get_dist_logger


@@ -35,35 +35,24 @@ def main():
     parser.add_argument(
         "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer"
     )
-    parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
-    parser.add_argument(
-        "--data_jsonl_output_dir",
-        type=str,
-        default="jsonl_output",
-        help="Output directory of spliced dataset with jsonl format",
-    )
-    parser.add_argument(
-        "--data_arrow_output_dir",
-        type=str,
-        default="arrow_output",
-        help="Output directory of spliced dataset with arrow format",
-    )
-    parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence")
+    parser.add_argument("--data_output_dirs", type=str, default="data_output_dirs", help="Data output directory")
+    parser.add_argument("--max_length", type=int, default=8192, help="Max length of each spliced tokenized sequence")
     parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins")
     args = parser.parse_args()

     if args.num_spliced_dataset_bins >= 100000:
         raise ValueError("Too many spliced divisions, must be smaller than 100000")

-    assert not os.path.exists(args.data_cache_dir), f"Find existed data cache dir {args.data_cache_dir}"
-    assert not os.path.exists(
-        args.data_jsonl_output_dir
-    ), f"Find existed jsonl data output dir {args.data_jsonl_output_dir}"
-    assert not os.path.exists(
-        args.data_arrow_output_dir
-    ), f"Find existed arrow data output dir {args.data_arrow_output_dir}"
-    os.makedirs(args.data_jsonl_output_dir)
-    os.makedirs(args.data_arrow_output_dir)
+    args.data_cache_dir = os.path.join(args.data_output_dirs, "cache")
+    args.data_jsonl_output_dir = os.path.join(args.data_output_dirs, "jsonl")
+    args.data_arrow_output_dir = os.path.join(args.data_output_dirs, "arrow")
+
+    if not os.path.exists(args.data_cache_dir):
+        os.makedirs(args.data_cache_dir)
+    if not os.path.exists(args.data_jsonl_output_dir):
+        os.makedirs(args.data_jsonl_output_dir)
+    if not os.path.exists(args.data_arrow_output_dir):
+        os.makedirs(args.data_arrow_output_dir)

     # Prepare to all input datasets
     input_data_paths = []

@@ -86,7 +75,7 @@ def main():
     train_splits.append(f"train[{start}%:{end}%]")

     # Prepare to the tokenizer.
-    tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir)
+    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
     tokenizer.add_bos_token = False
     tokenizer.add_eos_token = False
     if tokenizer.pad_token is None:
@@ -10,10 +10,10 @@ import math
 import os
 from multiprocessing import cpu_count

-from colossal_llama2.dataset.conversation import default_conversation
-from colossal_llama2.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft
+from colossal_llama.dataset.conversation import default_conversation
+from colossal_llama.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft
 from datasets import dataset_dict, load_dataset
-from transformers.models.llama.tokenization_llama import LlamaTokenizer
+from transformers import AddedToken, AutoTokenizer

 from colossalai.logging import get_dist_logger


@@ -32,35 +32,25 @@ def main():
     parser.add_argument(
         "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer"
     )
-    parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
-    parser.add_argument(
-        "--data_jsonl_output_dir",
-        type=str,
-        default="jsonl_output",
-        help="Output directory of spliced dataset with jsonl format",
-    )
-    parser.add_argument(
-        "--data_arrow_output_dir",
-        type=str,
-        default="arrow_output",
-        help="Output directory of spliced dataset with arrow format",
-    )
-    parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence")
+    parser.add_argument("--data_output_dirs", type=str, default="data_output_dirs", help="Data output directory")
+    parser.add_argument("--max_length", type=int, default=8192, help="Max length of each spliced tokenized sequence")
     parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins")
+    parser.add_argument("--llama_version", type=int, default=3, help="LLaMA version")
     args = parser.parse_args()

     if args.num_spliced_dataset_bins >= 100000:
         raise ValueError("Too many spliced divisions, must be smaller than 100000")

-    assert not os.path.exists(args.data_cache_dir), f"Find existed data cache dir {args.data_cache_dir}"
-    assert not os.path.exists(
-        args.data_jsonl_output_dir
-    ), f"Find existed jsonl data output dir {args.data_jsonl_output_dir}"
-    assert not os.path.exists(
-        args.data_arrow_output_dir
-    ), f"Find existed arrow data output dir {args.data_arrow_output_dir}"
-    os.makedirs(args.data_jsonl_output_dir)
-    os.makedirs(args.data_arrow_output_dir)
+    args.data_cache_dir = os.path.join(args.data_output_dirs, "cache")
+    args.data_jsonl_output_dir = os.path.join(args.data_output_dirs, "jsonl")
+    args.data_arrow_output_dir = os.path.join(args.data_output_dirs, "arrow")
+
+    if not os.path.exists(args.data_cache_dir):
+        os.makedirs(args.data_cache_dir)
+    if not os.path.exists(args.data_jsonl_output_dir):
+        os.makedirs(args.data_jsonl_output_dir)
+    if not os.path.exists(args.data_arrow_output_dir):
+        os.makedirs(args.data_arrow_output_dir)

     # Prepare to all input datasets
     input_data_paths = []

@@ -83,11 +73,20 @@ def main():
     train_splits.append(f"train[{start}%:{end}%]")

     # Prepare to the tokenizer.
-    tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir)
+    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
+
+    # Fix </s> split issue: https://github.com/huggingface/transformers/issues/23833
+    if args.llama_version == 2:
+        tokenizer.add_tokens(AddedToken("</s>", normalized=False, special=True), special_tokens=True)
+
     tokenizer.add_bos_token = False
     tokenizer.add_eos_token = False
     if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.unk_token
+        if tokenizer.unk_token is not None:
+            tokenizer.pad_token = tokenizer.unk_token
+        else:
+            tokenizer.pad_token = tokenizer.eos_token
+            tokenizer.unk_token = tokenizer.eos_token

     list_dataset = load_dataset(
         path="json",
@@ -1,9 +1,10 @@
-torch<2.0.0, >=1.12.1
-packaging==23.1
-colossalai==0.3.5
+torch==2.1.2
+huggingface-hub
+packaging==24.0
+colossalai==0.3.6
 autoflake==2.2.1
 black==23.9.1
-transformers==4.33.3
+transformers==4.34.1
 tensorboard==2.14.0
 six==1.16.0
 datasets
@@ -1,6 +1,6 @@
 import argparse

-from colossal_llama2.utils.stream_chat_patch import streaming_chat
+from colossal_llama.utils.stream_chat_patch import streaming_chat
 from transformers import AutoModelForCausalLM, AutoTokenizer

 SYSTEM = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
@@ -12,18 +12,18 @@ from contextlib import nullcontext

 import torch
 import torch.distributed as dist
-from colossal_llama2.dataset.loader import (
+from colossal_llama.dataset.loader import (
     DataCollatorForSupervisedDataset,
     StatefulDistributedSampler,
     load_tokenized_dataset,
 )
-from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
-from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
-from colossal_llama2.utils.froze import freeze_non_embeds_parameters
-from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
+from colossal_llama.utils.ckpt_io import load_checkpoint, save_checkpoint
+from colossal_llama.utils.flash_attention_patch import replace_with_flash_attention
+from colossal_llama.utils.froze import freeze_non_embeds_parameters
+from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
-from transformers import LlamaForCausalLM, LlamaTokenizer
+from transformers import AutoTokenizer, LlamaForCausalLM

 import colossalai
 from colossalai.accelerator import get_accelerator

@@ -89,7 +89,7 @@ def main() -> None:
     parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
     parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
     parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
-    parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
+    parser.add_argument("--max_length", type=int, default=8192, help="Model max length")
     parser.add_argument(
         "--mixed_precision",
         type=str,

@@ -136,7 +136,7 @@ def main() -> None:
     # ==============================
     # Initialize Distributed Training
     # ==============================
-    colossalai.launch_from_torch({})
+    colossalai.launch_from_torch()
     accelerator = get_accelerator()
     coordinator = DistCoordinator()

@@ -196,7 +196,7 @@ def main() -> None:
     # ======================================================
     # Initialize Tokenizer, Dataset, Collator and Dataloader
     # ======================================================
-    tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
+    tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
     if args.pad_token == "eos":
         tokenizer.pad_token = tokenizer.eos_token
     elif args.pad_token == "unk":

@@ -253,9 +253,11 @@ def main() -> None:
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")

     optimizer = HybridAdam(
-        model_params=filter(lambda p: p.requires_grad, model.parameters())
-        if args.freeze_non_embeds_params
-        else model.parameters(),
+        model_params=(
+            filter(lambda p: p.requires_grad, model.parameters())
+            if args.freeze_non_embeds_params
+            else model.parameters()
+        ),
         lr=args.lr,
         betas=(0.9, 0.95),
         weight_decay=args.weight_decay,
@@ -0,0 +1 @@
+1.0.0
@@ -66,7 +66,7 @@ def benchmark_train(args):
     # ==============================
     # Initialize Distributed Training
     # ==============================
-    colossalai.launch_from_torch({})
+    colossalai.launch_from_torch()
     coordinator = DistCoordinator()

     # ======================================================

@@ -37,7 +37,7 @@ def train(args):
     # ==============================
     # Initialize Distributed Training
     # ==============================
-    colossalai.launch_from_torch({})
+    colossalai.launch_from_torch()
     coordinator = DistCoordinator()

     # ==============================

@@ -39,7 +39,7 @@ def train(args):
     # ==============================
     # Initialize Distributed Training
     # ==============================
-    colossalai.launch_from_torch({})
+    colossalai.launch_from_torch()
     coordinator = DistCoordinator()

     # ======================================================

@@ -34,7 +34,7 @@ def train(args):
     # ==============================
     # Initialize Distributed Training
     # ==============================
-    colossalai.launch_from_torch({})
+    colossalai.launch_from_torch()
     coordinator = DistCoordinator()

     # ======================================================

@@ -29,7 +29,7 @@ def train(args):
     # ==============================
     # Initialize Distributed Training
     # ==============================
-    colossalai.launch_from_torch({})
+    colossalai.launch_from_torch()
     coordinator = DistCoordinator()

     # ==============================

@@ -81,7 +81,7 @@ def rm_and_merge(


 def main(args):
-    colossalai.launch_from_torch(config={}, seed=42)
+    colossalai.launch_from_torch(seed=42)
     accelerator = get_accelerator()
     world_size = dist.get_world_size()


@@ -81,7 +81,7 @@ def rm_and_merge(


 def main(args):
-    colossalai.launch_from_torch(config={}, seed=42)
+    colossalai.launch_from_torch(seed=42)
     world_size = dist.get_world_size()

     rank = dist.get_rank()
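All of the example and benchmark scripts above drop the positional config dict from `colossalai.launch_from_torch`. A minimal sketch of the updated initialization pattern (training code omitted; run under `torchrun`):

```python
# Minimal sketch of the config-free launch used by the updated examples above.
# Launch with e.g.: torchrun --nproc_per_node=8 this_script.py
import colossalai
from colossalai.cluster import DistCoordinator


def main():
    colossalai.launch_from_torch(seed=42)  # no positional config dict anymore
    coordinator = DistCoordinator()
    coordinator.print_on_master("distributed environment initialized")


if __name__ == "__main__":
    main()
```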
Binary file not shown.
@@ -57,7 +57,7 @@ def main():
     args = parse_args()

     # Launch ColossalAI
-    colossalai.launch_from_torch(config={}, seed=args.seed)
+    colossalai.launch_from_torch(seed=args.seed)
     coordinator = DistCoordinator()

     config = MixtralConfig.from_pretrained(args.model_name)

@@ -96,7 +96,11 @@ def main():
     if coordinator.rank == 0:
         text = ["Hello my name is"]
     else:
-        text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"]
+        text = [
+            "What's the largest country in the world?",
+            "How many people live in China?",
+            "帮我续写这首诗:离离原上草",
+        ]
     tokenizer.pad_token = tokenizer.unk_token
     inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())
@@ -50,7 +50,7 @@ def check_mixtral_moe_layer():


 def run_dist(rank: int, world_size: int, port: int):
-    colossalai.launch({}, rank, world_size, "localhost", port)
+    colossalai.launch(rank, world_size, "localhost", port)
     check_mixtral_moe_layer()


@@ -133,7 +133,7 @@ def check_mixtral_moe_layer():


 def run_dist(rank: int, world_size: int, port: int):
-    colossalai.launch({}, rank, world_size, "localhost", port)
+    colossalai.launch(rank, world_size, "localhost", port)
     check_mixtral_moe_layer()


@@ -145,7 +145,7 @@ def main():
     args = parse_args()

     # Launch ColossalAI
-    colossalai.launch_from_torch(config={}, seed=args.seed)
+    colossalai.launch_from_torch(seed=args.seed)
     coordinator = DistCoordinator()

     # Set plugin

@@ -195,9 +195,9 @@ def main():
     lr_scheduler = CosineAnnealingWarmupLR(
         optimizer=optimizer,
         total_steps=args.num_epochs * len(dataloader),
-        warmup_steps=args.warmup_steps
-        if args.warmup_steps is not None
-        else int(args.num_epochs * len(dataloader) * 0.025),
+        warmup_steps=(
+            args.warmup_steps if args.warmup_steps is not None else int(args.num_epochs * len(dataloader) * 0.025)
+        ),
         eta_min=0.1 * args.lr,
     )
@@ -5,7 +5,7 @@ This directory contains the applications that are powered by Colossal-AI.
 The list of applications include:

 - [X] [Open-Sora](https://github.com/hpcaitech/Open-Sora): Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models
-- [X] [Colossal-LLaMA-2](./Colossal-LLaMA-2/): Continual Pre-training of LLaMA-2.
+- [X] [Colossal-LLaMA](./Colossal-LLaMA/): Continual Pre-training and Supervisied Fine-tuning of LLaMA2 / LLaMA3.
 - [X] [ColossalEval](./ColossalEval): Evaluation Pipeline for LLMs.
 - [X] [ColossalChat](./Chat/README.md): Replication of ChatGPT with RLHF.
 - [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters.
@@ -246,7 +246,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,

 @compatibility(is_backward_compatible=True)
 class ActivationCheckpointCodeGen(CodeGen):
-    def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
+    def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, verbose=None) -> PythonCode:
         free_vars: List[str] = []
         body: List[str] = []
         globals_: Dict[str, Any] = {}

@@ -126,7 +126,7 @@ class AMPOptimizer(OptimizerWrapper):
         return self.grad_scaler.scale.item()

     def zero_grad(self, *args, **kwargs):
-        self.module.overflow_counter = torch.cuda.IntTensor([0])
+        self.module.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
         return self.optim.zero_grad(set_to_none=True)

     def step(self, *args, **kwargs):
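Both `torch.cuda.IntTensor` replacements in this commit follow the same accelerator-agnostic allocation pattern; a small self-contained sketch:

```python
# Sketch of the device-agnostic counter used above: the tensor lives on whatever device
# the active accelerator reports (CUDA, NPU, ...) instead of being hard-coded to CUDA.
import torch
from colossalai.accelerator import get_accelerator

overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
overflow_counter += 1  # updated in place on the accelerator device
print(overflow_counter)
```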
@@ -4,8 +4,8 @@ from typing import Optional, Set
 import torch
 import torch.nn as nn

-from colossalai.utils import _cast_float
-from colossalai.zero.legacy.gemini.tensor_utils import free_storage
+from colossalai.utils import _cast_float, get_current_device
+from colossalai.utils.common import free_storage

 from .region_manager import RegionManager
 from .util import GlobalRuntimeInfo

@@ -25,7 +25,7 @@ class BaseOffloadModule:
         self.model = model
         self.region_manager = region_manager
         self.grad_hook_list = []
-        self.overflow_counter = torch.cuda.IntTensor([0])
+        self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_current_device())

         self.grad_offload_stream = torch.cuda.current_stream() if is_sync else GlobalRuntimeInfo.d2h_stream
@@ -3,7 +3,8 @@ from typing import Dict, List, Tuple
 import torch
 from torch.fx import Node

-from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage
+from colossalai.utils.common import free_storage
+from colossalai.zero.gemini.chunk.chunk import alloc_storage


 class Region:

@@ -372,7 +372,7 @@ if AUTOCHUNK_AVAILABLE:
         if print_progress:
             get_logger().info("AutoChunk start codegen")

-    def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
+    def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, verbose=None) -> PythonCode:
         free_vars: List[str] = []
         body: List[str] = []
         globals_: Dict[str, Any] = {}
@@ -8,9 +8,18 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader

+SUPPORT_PEFT = False
+try:
+    import peft
+
+    SUPPORT_PEFT = True
+except ImportError:
+    pass
+
 import colossalai.interface.pretrained as pretrained_utils
 from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.quantization import BnbQuantizationConfig

 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory

@@ -221,6 +230,56 @@ class Booster:
         assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
         return self.plugin.no_sync(model, optimizer)

+    def enable_lora(
+        self,
+        model: nn.Module,
+        pretrained_dir: Optional[str] = None,
+        lora_config: "peft.LoraConfig" = None,
+        bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
+        quantize=False,
+    ) -> nn.Module:
+        """
+        Wrap the passed in model with LoRA modules for training. If pretrained directory is provided, lora configs and weights are loaded from that directory.
+        Lora in ColossalAI is implemented using Huggingface peft library, so the arguments for Lora configuration are same as those of peft.
+
+        Args:
+            model (nn.Module): The model to be appended with LoRA modules.
+            pretrained_dir(str, optional): The path to the pretrained directory, can be a local directory
+                or model_id of a PEFT configuration hosted inside a model repo on the Hugging Face Hub.
+                When set to None, create new lora configs and weights for the model using the passed in lora_config. Defaults to None.
+            lora_config: (peft.LoraConfig, optional): Passed in LoraConfig for peft. Defaults to None.
+        """
+        if not SUPPORT_PEFT:
+            raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!")
+
+        assert self.plugin is not None, f"Lora can only be enabled when a plugin is provided."
+        assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora."
+        if pretrained_dir is None:
+            assert (
+                lora_config is not None
+            ), "Please provide configuration for Lora when pretrained directory path isn't passed in."
+            assert isinstance(
+                lora_config, peft.LoraConfig
+            ), "The passed in configuration should be an instance of peft.LoraConfig."
+        if lora_config is None:
+            assert (
+                pretrained_dir is not None
+            ), "Please provide pretrained directory path if not passing in lora configuration."
+        if quantize is True:
+            if bnb_quantization_config is not None:
+                warnings.warn(
+                    "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
+                )
+            else:
+                bnb_quantization_config = BnbQuantizationConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=torch.bfloat16,
+                    bnb_4bit_use_double_quant=True,
+                    bnb_4bit_quant_type="nf4",
+                )
+
+        return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)
+
     def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
         """Load model from checkpoint.


@@ -323,3 +382,20 @@ class Booster:
             checkpoint (str): Path to the checkpoint. It must be a local file path.
         """
         self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
+
+    def save_lora_as_pretrained(
+        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+    ) -> None:
+        """
+        Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
+
+        Args:
+            model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
+            checkpoint (str): Path to the checkpoint directory. It must be a local path.
+            use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
+        """
+        if not SUPPORT_PEFT:
+            raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!")
+        assert self.plugin is not None, f"Lora can only be enabled when a plugin is provided."
+        assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora."
+        self.checkpoint_io.save_lora_as_pretrained(model, checkpoint, use_safetensors)
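Put together, the hooks above give the Booster a LoRA workflow. A hedged usage sketch (not taken from the repository; paths in angle brackets are placeholders, and a distributed context is assumed to have been set up with `colossalai.launch_from_torch()` first):

```python
# Hedged sketch of the LoRA flow enabled by the Booster methods added above.
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

model = AutoModelForCausalLM.from_pretrained("<PRETRAINED_MODEL_DIR>")  # placeholder path
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

booster = Booster(plugin=TorchDDPPlugin())  # a plugin whose support_lora() returns True
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)

model = booster.enable_lora(model, lora_config=lora_config)  # wrap with adapters before boosting
model, optimizer, *_ = booster.boost(model, optimizer)
# ... training loop ...
booster.save_lora_as_pretrained(model, "<LORA_CKPT_DIR>", use_safetensors=True)
```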
@@ -3,7 +3,7 @@ import logging
 import os
 import random
 from pathlib import Path
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple

 import numpy as np
 import torch

@@ -44,10 +44,10 @@ ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
 def get_param_info(optim: Optimizer):
     # Get a backup of necessary information of parameters for future use, which includes:
     # 1. A mapping from integer param_id to param32 shape.

     if optim is None:
         return {}
     param_info = {"id2shape": {}}

     start_index = 0
     for group in optim.param_groups:
         for param_id, param in enumerate(group["params"], start_index):

@@ -424,6 +424,7 @@ class GeminiPlugin(DPPluginBase):
         )
         self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None
         self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None
+        self.dp_size = self.zero_size * self.extra_dp_size

         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,

@@ -443,6 +444,9 @@ class GeminiPlugin(DPPluginBase):
     def support_no_sync(self) -> bool:
         return False

+    def support_lora(self) -> bool:
+        return False
+
     def control_precision(self) -> bool:
         return True

@@ -527,7 +531,7 @@ class GeminiPlugin(DPPluginBase):
         dataloader: Optional[DataLoader] = None,
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
-        optimizer_params_info = get_param_info(optimizer)
+        params_info = get_param_info(optimizer)
         if not isinstance(model, ModelWrapper):
             # convert model to sync bn
             # FIXME(ver217): gemini does not support sync bn

@@ -558,7 +562,7 @@ class GeminiPlugin(DPPluginBase):
                 **self.zero_optim_config,
                 **self.optim_kwargs,
                 tp_group=self.tp_group,
-                optimizer_params_info=optimizer_params_info,
+                params_info=params_info,
                 verbose=self.verbose,
             )

@@ -572,3 +576,8 @@ class GeminiPlugin(DPPluginBase):

     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        raise NotImplementedError
@@ -4,7 +4,7 @@ import warnings
 from contextlib import contextmanager
 from functools import partial
 from types import MethodType
-from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union

 import numpy as np
 import torch

@@ -34,7 +34,6 @@ from colossalai.zero.low_level import LowLevelZeroOptimizer

 from .pp_plugin_base import PipelinePluginBase

-DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3
 SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]

 PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}

@@ -213,12 +212,7 @@ def get_param_info(optim: Optimizer):

     if optim is None:
         return {}
-    param_info = {
-        "param_groups": [],
-        "param2id": {},
-        "id2param": {},
-        "param2shape": {},
-    }
+    param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
     start_index = 0
     for group in optim.param_groups:
         packed_group = {k: v for k, v in group.items() if k != "params"}

@@ -947,6 +941,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
         gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
         enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
+        make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64.
+
     """

     def __init__(

@@ -987,8 +983,11 @@ class HybridParallelPlugin(PipelinePluginBase):
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
+        num_layers_per_stage: Optional[List[int]] = None,
         gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
+        make_vocab_size_divisible_by: int = 64,
+        dp_outside: bool = True,
     ) -> None:
         super().__init__()
         assert (

@@ -1036,7 +1035,12 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.enable_flash_attention = enable_flash_attention
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
-        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
+        if dp_outside:
+            self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
+            self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
+        else:
+            self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
+            self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
         self.stage_manager = None
         self.schedule = None
         self.custom_policy = custom_policy

@@ -1050,9 +1054,10 @@ class HybridParallelPlugin(PipelinePluginBase):
             assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
             self.stage_manager = PipelineStageManager(
                 self.pg_mesh,
-                pipeline_axis=PP_AXIS,
+                pipeline_axis=self.pp_axis,
                 enable_interleave=pp_style == "interleaved",
                 num_model_chunks=num_model_chunks,
+                num_layers_per_stage=num_layers_per_stage,
             )

             if pp_style == "interleaved":

@@ -1074,13 +1079,13 @@ class HybridParallelPlugin(PipelinePluginBase):
         else:
             raise NotImplementedError()

-        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
-        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
-        self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
+        self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
+        self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
+        self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
         if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
-            self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+            self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         else:
-            self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
+            self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)

         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,

@@ -1095,6 +1100,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             sequence_parallelism_mode=sequence_parallelism_mode,
             enable_sequence_overlap=enable_sequence_overlap,
             parallel_output=parallel_output,
+            make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             gradient_checkpoint_config=gradient_checkpoint_config,
         )
         self.amp_config = dict(

@@ -1150,6 +1156,9 @@ class HybridParallelPlugin(PipelinePluginBase):
     def support_no_sync(self) -> bool:
         return True

+    def support_lora(self) -> bool:
+        return False
+
     def control_checkpoint_io(self) -> bool:
         return True

@@ -1170,7 +1179,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             and self.sequence_parallelism_mode == "all_to_all"
         )
         if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
-            dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS])
+            dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
         else:
             dp_group = self.dp_group
         model = HybridParallelModule(

@@ -1318,7 +1327,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         _kwargs = kwargs.copy()
         distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
         sampler = distributed_sampler_cls(
-            dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
+            dataset,
+            num_replicas=self.pg_mesh.size(self.dp_axis),
+            rank=self.pg_mesh.coordinate(self.dp_axis),
+            shuffle=shuffle,
         )

         # Deterministic dataloader

@@ -1347,3 +1359,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             self.zero_stage != 2
         ), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed."
         return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
+
+    def enable_lora(
+        self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> Module:
+        raise NotImplementedError
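The `dp_outside` flag introduced above only decides which axis of the process-group mesh is outermost; a standalone illustration of the two orderings:

```python
# Illustration of the dp_outside switch above: it picks whether the data-parallel axis or
# the pipeline axis is index 0 before ProcessGroupMesh is built. No distributed launch needed.
def mesh_axes(dp_outside: bool) -> dict:
    if dp_outside:
        dp_axis, pp_axis, tp_axis, sp_axis = 0, 1, 2, 3
    else:
        pp_axis, dp_axis, tp_axis, sp_axis = 0, 1, 2, 3
    return {"dp": dp_axis, "pp": pp_axis, "tp": tp_axis, "sp": sp_axis}


print(mesh_axes(True))   # {'dp': 0, 'pp': 1, 'tp': 2, 'sp': 3}
print(mesh_axes(False))  # {'dp': 1, 'pp': 0, 'tp': 2, 'sp': 3}
```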
@@ -1,12 +1,15 @@
+import enum
 import logging
 import os
+import warnings
 from functools import partial
 from pathlib import Path
 from types import MethodType
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple

 import torch
 import torch.nn as nn
+from torch.nn import Parameter
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map

@@ -25,6 +28,7 @@ from colossalai.checkpoint_io.utils import (
     sharded_optimizer_loading_epilogue,
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.zero import LowLevelZeroOptimizer

 from .dp_plugin_base import DPPluginBase

@@ -42,6 +46,12 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
 SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]


+class OptimizerParamCheckState(enum.Enum):
+    ORIGIN_PARAM_FINDED = 0
+    ORIGIN_PARAM_NOT_FIND = -1
+    LORA_PARM_EXISTED = -2
+
+
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
     def __init__(self, module: nn.Module, precision: str) -> None:
         super().__init__(module)

@@ -209,6 +219,19 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
         model.update_master_params()

+    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+        from peft import PeftModel
+
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        peft_model = model.unwrap()
+        assert isinstance(
+            peft_model, PeftModel
+        ), "The model doesn't have lora adapters, please enable lora before saving."
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+

 class LowLevelZeroPlugin(DPPluginBase):
     """

@@ -288,6 +311,7 @@ class LowLevelZeroPlugin(DPPluginBase):
             cpu_offload=cpu_offload,
             master_weights=master_weights,
         )
+        self.lora_enabled = False
         self.verbose = verbose

         # set class name with stage, for better error message

@@ -296,6 +320,9 @@ class LowLevelZeroPlugin(DPPluginBase):
     def support_no_sync(self) -> bool:
         return self.stage == 1

+    def support_lora(self) -> bool:
+        return False
+
     def control_precision(self) -> bool:
         return True

@@ -308,6 +335,79 @@ class LowLevelZeroPlugin(DPPluginBase):
     def supported_devices(self) -> List[str]:
         return ["cuda", "npu"]

+    def support_lora(self) -> bool:
+        return True
+
+    def enable_lora(
+        self,
+        model: nn.Module,
+        pretrained_dir: Optional[str] = None,
+        lora_config: Optional[Dict] = None,
+        bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
+    ) -> nn.Module:
+        from peft import PeftModel, get_peft_model
+
+        assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
+        self.lora_enabled = True
+        warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
+
+        if bnb_quantization_config is not None:
+            model = quantize_model(model, bnb_quantization_config)
+
+        if pretrained_dir is None:
+            peft_model = get_peft_model(model, lora_config)
+        else:
+            peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
+        return peft_model
+
+    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
+        origin_param_id = id(origin_param)
+        for group_id, param_group in enumerate(optimizer.param_groups):
+            for p in param_group["params"]:
+                if id(p) == origin_param_id:
+                    return group_id
+        return -1
+
+    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
+        origin_param_id = id(origin_param)
+        lora_param_id = id(lora_param)
+        target_group_id = None
+        for group_id, param_group in enumerate(optimizer.param_groups):
+            for p in param_group["params"]:
+                if id(p) == lora_param_id:
+                    # check if the lora parameter exists.
+                    return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
+                if id(p) == origin_param_id:
+                    target_group_id = group_id
+        if target_group_id is not None:
+            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED
+        else:
+            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND
+
+    def add_lora_params_to_optimizer(self, model, optimizer):
+        """add lora parameters to optimizer"""
+        name2param = {}
+        for name, param in model.named_parameters():
+            name2param[name] = param
+
+        for name, param in name2param.items():
+            if "lora_A" in name or "lora_B" in name:
+                origin_key = name.replace("lora_A.", "")
+                origin_key = origin_key.replace("lora_B.", "")
+                origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer")
+                origin_param = name2param[origin_key]
+                group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
+                if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
+                    warnings.warn(
+                        "Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
+                    )
+                elif (
+                    check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
+                    and group_id is not None
+                    and group_id >= 0
+                ):
+                    optimizer.param_groups[group_id]["params"].append(param)
+
     def configure(
         self,
         model: nn.Module,

@@ -316,6 +416,15 @@ class LowLevelZeroPlugin(DPPluginBase):
         dataloader: Optional[DataLoader] = None,
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
+        if self.lora_enabled:
+            from peft import PeftModel
+
+            assert isinstance(
+                model, PeftModel
+            ), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
+            if optimizer is not None:
+                self.add_lora_params_to_optimizer(model, optimizer)
+
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.precision)
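The optimizer bookkeeping added above appends each `lora_A`/`lora_B` tensor to the param group that already holds its base-layer weight; a standalone toy version of that grouping step (no peft required):

```python
# Toy version of the grouping logic above: a LoRA weight is appended to the optimizer
# param group that already contains the matching frozen base-layer weight.
import torch

base = torch.nn.Parameter(torch.zeros(4, 4))
lora_A = torch.nn.Parameter(torch.zeros(2, 4))
optimizer = torch.optim.SGD([{"params": [base], "lr": 1e-3}], lr=1e-3)


def find_group(optim, origin_param):
    for group_id, group in enumerate(optim.param_groups):
        if any(p is origin_param for p in group["params"]):
            return group_id
    return -1


gid = find_group(optimizer, base)
if gid >= 0:
    optimizer.param_groups[gid]["params"].append(lora_A)  # adapter trains with the same hyper-params
print(gid, len(optimizer.param_groups[gid]["params"]))  # -> 0 2
```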
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple

 import torch.nn as nn
 from torch.optim import Optimizer

@@ -33,6 +33,10 @@ class Plugin(ABC):
     def support_no_sync(self) -> bool:
         pass

+    @abstractmethod
+    def support_lora(self) -> bool:
+        pass
+
     @abstractmethod
     def configure(
         self,

@@ -63,6 +67,12 @@ class Plugin(ABC):
         Context manager to disable gradient synchronization.
         """

+    @abstractmethod
+    def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
+        """
+        Add LoRA modules to the model passed in. Should only be called in booster.enable_lora().
+        """
+
     @abstractmethod
     def prepare_dataloader(
         self,
@ -1,4 +1,4 @@
|
|||
from typing import Callable, Iterator, List, Optional, Tuple
|
||||
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
@ -9,6 +9,8 @@ from torch.utils.data import DataLoader
|
|||
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.quantization import BnbQuantizationConfig, quantize_model
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .dp_plugin_base import DPPluginBase
|
||||
|
||||
|
@ -116,6 +118,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
|||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
|
||||
|
||||
def save_lora_as_pretrained(
|
||||
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Save the lora adapters and adapter configuration file to checkpoint directory.
|
||||
"""
|
||||
from peft import PeftModel
|
||||
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||
if self.coordinator.is_master():
|
||||
peft_model = model.unwrap()
|
||||
assert isinstance(
|
||||
peft_model, PeftModel
|
||||
), "The model doesn't have lora adapters, please enable lora before saving."
|
||||
peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
|
||||
|
||||
|
||||
class TorchDDPModel(ModelWrapper):
|
||||
def __init__(self, module: nn.Module, *args, **kwargs) -> None:
|
||||
|
@ -173,6 +191,9 @@ class TorchDDPPlugin(DPPluginBase):
|
|||
def support_no_sync(self) -> bool:
|
||||
return True
|
||||
|
||||
def support_lora(self) -> bool:
|
||||
return True
|
||||
|
||||
def control_precision(self) -> bool:
|
||||
return False
|
||||
|
||||
|
@ -183,7 +204,7 @@ class TorchDDPPlugin(DPPluginBase):
|
|||
return True
|
||||
|
||||
def supported_devices(self) -> List[str]:
|
||||
return ["cuda"]
|
||||
return ["cuda", "npu"]
|
||||
|
||||
def configure(
|
||||
self,
|
||||
|
@ -194,7 +215,7 @@ class TorchDDPPlugin(DPPluginBase):
|
|||
lr_scheduler: Optional[LRScheduler] = None,
|
||||
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
|
||||
# cast model to cuda
|
||||
model = model.cuda()
|
||||
model = model.to(get_current_device())
|
||||
|
||||
# convert model to sync bn
|
||||
model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
|
||||
|
@ -216,3 +237,21 @@ class TorchDDPPlugin(DPPluginBase):
|
|||
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
|
||||
assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin."
|
||||
return model.module.no_sync()
|
||||
|
||||
def enable_lora(
|
||||
self,
|
||||
model: nn.Module,
|
||||
pretrained_dir: Optional[str] = None,
|
||||
lora_config: Optional[Dict] = None,
|
||||
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
|
||||
) -> nn.Module:
|
||||
from peft import PeftModel, get_peft_model
|
||||
|
||||
if bnb_quantization_config is not None:
|
||||
model = quantize_model(model, bnb_quantization_config)
|
||||
|
||||
assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model."
|
||||
if pretrained_dir is None:
|
||||
return get_peft_model(model, lora_config)
|
||||
else:
|
||||
return PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
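A minimal sketch of how these new LoRA hooks are meant to be used at the call site. It assumes a torchrun launch, a toy model, and that the Booster-level `enable_lora`/`boost` wrappers simply forward to the plugin methods shown above; the adapter rank and target module names are illustrative only.

```python
import torch
import torch.nn as nn
from peft import LoraConfig

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch()  # run under torchrun; uses the new config-free signature

plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)

# Toy model; real usage would target the attention/MLP linears of a transformer.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["0", "2"])

# enable_lora must run before boosting, as asserted in the plugin above.
model = booster.enable_lora(model, lora_config=lora_config)
model, optimizer, *_ = booster.boost(model, torch.optim.SGD(model.parameters(), lr=1e-3))
```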
@ -2,7 +2,7 @@ import logging
|
|||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterable, Iterator, List, Optional, Tuple
|
||||
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -318,6 +318,9 @@ class TorchFSDPPlugin(DPPluginBase):
|
|||
def support_no_sync(self) -> bool:
|
||||
return False
|
||||
|
||||
def support_lora(self) -> bool:
|
||||
return False
|
||||
|
||||
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
|
||||
raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
|
||||
|
||||
|
@ -361,3 +364,8 @@ class TorchFSDPPlugin(DPPluginBase):
|
|||
|
||||
def get_checkpoint_io(self) -> CheckpointIO:
|
||||
return TorchFSDPCheckpointIO()
|
||||
|
||||
def enable_lora(
|
||||
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
|
||||
) -> nn.Module:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -335,3 +335,20 @@ class CheckpointIO(ABC):
|
|||
"""
|
||||
state_dict = torch.load(checkpoint)
|
||||
lr_scheduler.load_state_dict(state_dict)
|
||||
|
||||
# ================================================================================
|
||||
# Abstract method for lora saving implementation.
|
||||
# ================================================================================
|
||||
|
||||
@abstractmethod
|
||||
def save_lora_as_pretrained(
|
||||
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
|
||||
|
||||
Args:
|
||||
model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
|
||||
checkpoint (str): Path to the checkpoint directory. It must be a local path.
|
||||
use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
|
||||
"""
|
||||
|
|
|
@ -228,3 +228,6 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
self.__class__.__name__, "\n\t".join(error_msgs)
|
||||
)
|
||||
)
|
||||
|
||||
def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -14,6 +14,12 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|||
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.tensor.padded_tensor import (
|
||||
init_as_padded_tensor,
|
||||
is_padded_tensor,
|
||||
to_padded_tensor,
|
||||
to_unpadded_tensor,
|
||||
)
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .general_checkpoint_io import GeneralCheckpointIO
|
||||
|
@ -32,6 +38,7 @@ from .utils import (
|
|||
save_param_groups,
|
||||
save_state_dict,
|
||||
save_state_dict_shards,
|
||||
search_padding_dim,
|
||||
search_tp_partition_dim,
|
||||
sharded_optimizer_loading_epilogue,
|
||||
)
|
||||
|
@ -89,6 +96,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
if param is None:
|
||||
continue
|
||||
# Gather tensor pieces when using tensor parallel.
|
||||
if is_padded_tensor(param):
|
||||
param = to_unpadded_tensor(param)
|
||||
param_ = gather_distributed_param(param, keep_vars=False)
|
||||
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
|
||||
if block is not None:
|
||||
|
@ -231,7 +240,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
# When pipeline is used, each stage produces its own shard files and index files.
|
||||
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
|
||||
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
|
||||
|
||||
final_index_file_path = copy.deepcopy(save_index_file)
|
||||
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
|
||||
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
@ -251,6 +259,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
use_safetensors=use_safetensors,
|
||||
use_pp_format=True,
|
||||
)
|
||||
|
||||
if control_saving:
|
||||
assert (
|
||||
self.dp_rank == 0 and self.tp_rank == 0
|
||||
|
@ -867,6 +876,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
dist.all_gather(gather_tensor, v, group=tp_group)
|
||||
v = torch.cat(gather_tensor, dim=partition_dim)
|
||||
|
||||
padding_dim = search_padding_dim(v.shape, original_shape)
|
||||
if padding_dim is not None:
|
||||
v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
|
||||
v = to_unpadded_tensor(v)
|
||||
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
|
||||
|
@ -899,6 +913,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# Shard state along tensor parallel group.
|
||||
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
|
||||
global_shape = current_shape
|
||||
if partition_dim is not None:
|
||||
# pad embedding params
|
||||
global_shape = (
|
||||
*current_shape[:partition_dim],
|
||||
current_shape[partition_dim] * self.tp_size,
|
||||
*current_shape[partition_dim + 1 :],
|
||||
)
|
||||
|
||||
padding_dim = search_padding_dim(global_shape, original_shape)
|
||||
if padding_dim is not None:
|
||||
v = to_padded_tensor(v, global_shape[padding_dim], padding_dim)
|
||||
|
||||
if partition_dim is not None:
|
||||
slice_size = current_shape[partition_dim]
|
||||
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
|
||||
|
|
|
@ -120,6 +120,15 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
|
|||
return partition_dim
|
||||
|
||||
|
||||
def search_padding_dim(global_shape: torch.Size, original_shape: torch.Size) -> Optional[int]:
|
||||
padding_dim = None
|
||||
for dim, length in enumerate(global_shape):
|
||||
if length > original_shape[dim]:
|
||||
padding_dim = dim
|
||||
break
|
||||
return padding_dim
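A quick illustrative check of this helper; the shapes are made up, with a vocabulary padded from 50257 to 50304 rows being a typical case.

```python
import torch

# padded vocab: dim 0 grew from 50257 to 50304, so it is reported as the padding dim
assert search_padding_dim(torch.Size([50304, 768]), torch.Size([50257, 768])) == 0
# identical shapes mean no padding was applied
assert search_padding_dim(torch.Size([768, 768]), torch.Size([768, 768])) is None
```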
# ======================================
|
||||
# Helper classes and functions for saving shard file
|
||||
# ======================================
|
||||
|
|
|
@ -1,22 +1,27 @@
|
|||
import threading
|
||||
|
||||
|
||||
class SingletonMeta(type):
|
||||
"""
|
||||
The Singleton class can be implemented in different ways in Python. Some
|
||||
possible methods include: base class, decorator, metaclass. We will use the
|
||||
metaclass because it is best suited for this purpose.
|
||||
Thread-safe Singleton Meta with double-checked locking.
|
||||
Reference: https://en.wikipedia.org/wiki/Double-checked_locking
|
||||
"""
|
||||
|
||||
_instances = {}
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
"""
|
||||
Possible changes to the value of the `__init__` argument do not affect
|
||||
the returned instance.
|
||||
"""
|
||||
# First check (without locking) for performance reasons
|
||||
if cls not in cls._instances:
|
||||
instance = super().__call__(*args, **kwargs)
|
||||
cls._instances[cls] = instance
|
||||
# Acquire a lock before proceeding to the second check
|
||||
with cls._lock:
|
||||
# Second check with lock held to ensure thread safety
|
||||
if cls not in cls._instances:
|
||||
instance = super().__call__(*args, **kwargs)
|
||||
cls._instances[cls] = instance
|
||||
else:
|
||||
assert (
|
||||
len(args) == 0 and len(kwargs) == 0
|
||||
), f"{cls.__name__} is a singleton class and a instance has been created."
|
||||
), f"{cls.__name__} is a singleton class and an instance has been created."
|
||||
|
||||
return cls._instances[cls]
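A small usage sketch of the double-checked locking behavior; the `ConfigStore` class is hypothetical and only assumes the `SingletonMeta` defined above is in scope.

```python
import threading

class ConfigStore(metaclass=SingletonMeta):
    def __init__(self):
        self.values = {}

created = []

def build():
    created.append(ConfigStore())  # no args, so repeated calls pass the singleton assert

threads = [threading.Thread(target=build) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# every thread observed the same instance, created exactly once under the lock
assert all(obj is created[0] for obj in created)
```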
@ -625,7 +625,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
|||
if CODEGEN_AVAILABLE:
|
||||
|
||||
class ActivationCheckpointCodeGen(CodeGen):
|
||||
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
|
||||
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, verbose=None) -> PythonCode:
|
||||
free_vars: List[str] = []
|
||||
body: List[str] = []
|
||||
globals_: Dict[str, Any] = {}
|
||||
|
|
|
@ -270,7 +270,7 @@ def llama_rmsnorm_forward(
|
|||
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
|
||||
|
||||
|
||||
class NopadLlamaMLP(ParallelModule, LlamaMLP):
|
||||
class NopadLlamaMLP(LlamaMLP, ParallelModule):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
|
@ -392,7 +392,7 @@ class NopadLlamaMLP(ParallelModule, LlamaMLP):
|
|||
return f"gate_up_proj MergedLinear1D_Col: in_features={self.gate_up_weight.shape[1]}x2, out_features={self.gate_up_weight.shape[2]}, bias=False"
|
||||
|
||||
|
||||
class NopadLlamaAttention(ParallelModule, LlamaAttention):
|
||||
class NopadLlamaAttention(LlamaAttention, ParallelModule):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
|
|
|
@ -2,20 +2,15 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.context import Config
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import set_seed
|
||||
|
||||
|
||||
def launch(
|
||||
config: Union[str, Path, Config, Dict],
|
||||
rank: int,
|
||||
world_size: int,
|
||||
host: str,
|
||||
|
@ -44,8 +39,6 @@ def launch(
|
|||
Raises:
|
||||
Exception: Raise exception when config type is wrong
|
||||
"""
|
||||
if rank == 0:
|
||||
warnings.warn("`config` is deprecated and will be removed soon.")
|
||||
|
||||
cur_accelerator = get_accelerator()
|
||||
|
||||
|
@ -68,7 +61,6 @@ def launch(
|
|||
|
||||
|
||||
def launch_from_slurm(
|
||||
config: Union[str, Path, Config, Dict],
|
||||
host: str,
|
||||
port: int,
|
||||
backend: str = "nccl",
|
||||
|
@ -95,7 +87,6 @@ def launch_from_slurm(
|
|||
)
|
||||
|
||||
launch(
|
||||
config=config,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host=host,
|
||||
|
@ -107,7 +98,6 @@ def launch_from_slurm(
|
|||
|
||||
|
||||
def launch_from_openmpi(
|
||||
config: Union[str, Path, Config, Dict],
|
||||
host: str,
|
||||
port: int,
|
||||
backend: str = "nccl",
|
||||
|
@ -135,7 +125,6 @@ def launch_from_openmpi(
|
|||
)
|
||||
|
||||
launch(
|
||||
config=config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
|
@ -147,9 +136,7 @@ def launch_from_openmpi(
|
|||
)
|
||||
|
||||
|
||||
def launch_from_torch(
|
||||
config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024, verbose: bool = True
|
||||
):
|
||||
def launch_from_torch(backend: str = "nccl", seed: int = 1024, verbose: bool = True):
|
||||
"""A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
|
||||
from the environment variables set by PyTorch
|
||||
|
||||
|
@ -171,7 +158,6 @@ def launch_from_torch(
|
|||
)
|
||||
|
||||
launch(
|
||||
config=config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
|
|
|
@ -119,10 +119,6 @@ class FlashAttentionLoader(KernelLoader):
|
|||
]
|
||||
|
||||
|
||||
class FlashAttentionWithPaddingMaskLoader(KernelLoader):
|
||||
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension]
|
||||
|
||||
|
||||
class FlashAttentionWithCustomMaskLoader(KernelLoader):
|
||||
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ class Worker:
|
|||
# initialize a ray collective group, otherwise colossalai distributed env won't be built successfully
|
||||
collective.init_collective_group(world_size, rank, "nccl", "default")
|
||||
# initialize and set distributed environment
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..")
|
||||
log_cuda_info("Worker.setup")
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@ class CaiInferEngine:
|
|||
import colossalai
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
colossalai.launch_from_torch(config={})
|
||||
colossalai.launch_from_torch()
|
||||
|
||||
model = LlamaForCausalLM.from_pretrained("your_path_to_model")
|
||||
tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
|
||||
|
|
|
@ -36,7 +36,7 @@ from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
|
|||
import colossalai
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
colossalai.launch_from_torch(config={})
|
||||
colossalai.launch_from_torch()
|
||||
|
||||
model = LlamaForCausalLM.from_pretrained("/path/to/model")
|
||||
tokenizer = LlamaTokenizer.from_pretrained("/path/to/model")
|
||||
|
@ -57,27 +57,27 @@ We conducted multiple benchmark tests to evaluate the performance. We compared t
|
|||
### Llama Throughput (tokens/s) | input length=1024, output length=128
|
||||
|
||||
#### A10 7b, fp16
|
||||
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)|
|
||||
| :---: | :---: | :---: | :---: | :---: | :---: | :---:|
|
||||
| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM |
|
||||
| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM |
|
||||
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16) |
|
||||
|:----------------------------:|:-----:|:-----:|:------:|:------:|:------:|:------:|
|
||||
| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM |
|
||||
| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM | OOM |
|
||||
|
||||
#### A10 13b, fp16
|
||||
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
|
||||
| :---: | :---: | :---: | :---: | :---: |
|
||||
| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
|
||||
| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |
|
||||
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(4) |
|
||||
|:----------------------------:|:-----:|:-----:|:-----:|:-----:|
|
||||
| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
|
||||
| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |
|
||||
|
||||
|
||||
#### A800 7b, fp16
|
||||
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
|
||||
| :---: | :---: | :---: | :---: | :---: | :---: |
|
||||
| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
|
||||
| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |
|
||||
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
|
||||
|:----------------------------:|:-----:|:------:|:------:|:------:|:------:|
|
||||
| Pipeline Inference | 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
|
||||
| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |
|
||||
|
||||
|
||||
#### A800 13b, fp16
|
||||
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
|
||||
| :---: | :---: | :---: | :---: | :---: | :---: |
|
||||
| Pipeline Inference | 41.78 | 94.18 | 172.67| 310.75| 470.15 |
|
||||
| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 |
|
||||
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
|
||||
|:----------------------------:|:-----:|:-----:|:------:|:------:|:------:|
|
||||
| Pipeline Inference | 41.78 | 94.18 | 172.67 | 310.75 | 470.15 |
|
||||
| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 |
|
||||
|
|
|
@ -12,7 +12,7 @@ from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
|
|||
GIGABYTE = 1024**3
|
||||
MEGABYTE = 1024 * 1024
|
||||
|
||||
colossalai.launch_from_torch(config={})
|
||||
colossalai.launch_from_torch()
|
||||
|
||||
|
||||
def data_gen(batch_size: int = 4, seq_len: int = 512):
|
||||
|
|
|
@ -56,7 +56,7 @@ class Worker:
|
|||
# initialize a ray collective group, otherwise colossalai distributed env won't be built successfully
|
||||
collective.init_collective_group(world_size, rank, "nccl", "default")
|
||||
# initialize and set distributed environment
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..")
|
||||
log_cuda_info("Worker.setup")
|
||||
|
||||
|
|
|
@ -98,7 +98,7 @@ class ColossalInferenceHandler(BaseHandler, ABC):
|
|||
self.model.cuda()
|
||||
self.model.eval()
|
||||
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
|
||||
colossalai.launch(rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
|
||||
logger.info("Initializing TPInferEngine ...")
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if self.tp_size > 1 else False, extra_kwargs={"inference_only": True}
|
||||
|
|
|
@ -114,7 +114,7 @@ def run_worker(rank, args, master_func):
|
|||
port = args.master_port
|
||||
backend = "nccl" if device == "cuda" else "gloo"
|
||||
|
||||
launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
|
||||
launch(rank, world_size, host, int(port), backend, verbose=False)
|
||||
ppg.set_global_info(
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
|
|
|
@ -8,7 +8,7 @@ Licensed under the MIT License.
|
|||
"""
|
||||
import torch
|
||||
|
||||
from colossalai.utils import multi_tensor_applier
|
||||
from colossalai.utils import get_current_device, multi_tensor_applier
|
||||
|
||||
|
||||
class FusedAdam(torch.optim.Optimizer):
|
||||
|
@ -75,7 +75,7 @@ class FusedAdam(torch.optim.Optimizer):
|
|||
fused_optim = FusedOptimizerLoader().load()
|
||||
|
||||
# Skip buffer
|
||||
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=get_current_device())
|
||||
self.multi_tensor_adam = fused_optim.multi_tensor_adam
|
||||
else:
|
||||
raise RuntimeError("FusedAdam requires cuda extensions")
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Any, Optional
|
|||
import torch
|
||||
|
||||
from colossalai.kernel.kernel_loader import FusedOptimizerLoader
|
||||
from colossalai.utils import multi_tensor_applier
|
||||
from colossalai.utils import get_current_device, multi_tensor_applier
|
||||
|
||||
from .cpu_adam import CPUAdam
|
||||
|
||||
|
@ -87,7 +87,7 @@ class HybridAdam(CPUAdam):
|
|||
if torch.cuda.is_available():
|
||||
fused_optim = FusedOptimizerLoader().load()
|
||||
self.gpu_adam_op = fused_optim.multi_tensor_adam
|
||||
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=get_current_device())
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None, div_scale: float = -1):
|
||||
|
|
|
@ -45,6 +45,18 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
|
|||
return unpickle
|
||||
|
||||
|
||||
def check_for_nccl_backend(group):
|
||||
pg = group or c10d._get_default_group()
|
||||
# Gate PG wrapper check on Gloo availability.
|
||||
if c10d._GLOO_AVAILABLE:
|
||||
# It is not expected for PG to be wrapped many times, but support it just
|
||||
# in case
|
||||
while isinstance(pg, c10d._ProcessGroupWrapper):
|
||||
pg = pg.wrapped_pg
|
||||
|
||||
return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
|
||||
|
||||
|
||||
# NOTE: FIXME: NPU DOES NOT support isend nor irecv, so broadcast is kept for future use
|
||||
def _broadcast_object_list(
|
||||
object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
|
||||
|
|
|
@ -7,7 +7,7 @@ from torch.nn import Module
|
|||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.utils import get_current_device
|
||||
|
@ -327,7 +327,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
self.send_forward(output_obj)
|
||||
|
||||
if outputs is not None:
|
||||
outputs = merge_batch(outputs)
|
||||
if isinstance(model, ModelWrapper):
|
||||
model = model.unwrap()
|
||||
batch_size_dim = getattr(model, "batch_size_dim", 0)
|
||||
outputs = merge_batch(outputs, batch_size_dim)
|
||||
return {"loss": accum_loss, "outputs": outputs}
|
||||
|
||||
def run_forward_backward(
|
||||
|
@ -410,7 +413,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
|
||||
|
||||
if outputs is not None:
|
||||
outputs = merge_batch(outputs)
|
||||
if isinstance(model, ModelWrapper):
|
||||
model = model.unwrap()
|
||||
batch_size_dim = getattr(model, "batch_size_dim", 0)
|
||||
outputs = merge_batch(outputs, batch_size_dim)
|
||||
return {"loss": accum_loss, "outputs": outputs}
|
||||
|
||||
def forward_backward_step(
|
||||
|
|
|
@ -27,16 +27,18 @@ class PipelineStageManager:
|
|||
pipeline_axis: int,
|
||||
enable_interleave: bool = False,
|
||||
num_model_chunks: int = 1,
|
||||
num_layers_per_stage: Optional[List[int]] = None,
|
||||
) -> None:
|
||||
assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"
|
||||
|
||||
self.num_layers_per_stage = None
|
||||
|
||||
self.pg_mesh = pg_mesh
|
||||
self.pipeline_axis = pipeline_axis
|
||||
self.prev_rank: Optional[Tuple[int, ...]] = None
|
||||
self.next_rank: Optional[Tuple[int, ...]] = None
|
||||
self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
|
||||
if num_layers_per_stage is not None:
|
||||
assert len(num_layers_per_stage) == self.num_stages
|
||||
self.num_layers_per_stage = num_layers_per_stage
|
||||
|
||||
# init prev and next coord
|
||||
coord = self.pg_mesh.coordinate()
|
||||
|
@ -56,6 +58,8 @@ class PipelineStageManager:
|
|||
self.p2p_groups[tuple(ranks_in_group)] = group
|
||||
|
||||
self.is_interleave = enable_interleave
|
||||
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
|
||||
self.num_model_chunks: int = num_model_chunks
|
||||
if enable_interleave:
|
||||
# use circle p2p communication
|
||||
# add the process group of the first rank and the last rank
|
||||
|
@ -64,59 +68,11 @@ class PipelineStageManager:
|
|||
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
|
||||
self.p2p_groups[tuple(ranks_in_group)] = group
|
||||
|
||||
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
|
||||
self.num_model_chunks: int = num_model_chunks
|
||||
|
||||
# for shardformer, hold stage indices of model
|
||||
self.stage_indices: List[Tuple[int, int]]
|
||||
# for shardformer, hold model chunk id
|
||||
self.model_chunk_id: Optional[int] = None
|
||||
|
||||
@property
|
||||
def control_distribute_layers(self) -> bool:
|
||||
return self.num_layers_per_stage is not None
|
||||
|
||||
def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None:
|
||||
"""Set the distribution configuration.
|
||||
This allows user to customize the number of layers for each stage.
|
||||
|
||||
Args:
|
||||
num_model_layers (int): Number of layers in the model.
|
||||
num_layers_per_stage (List[int]): Number of layers for each stage.
|
||||
"""
|
||||
assert all([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage])
|
||||
assert sum(num_layers_per_stage) == num_model_layers
|
||||
assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1)
|
||||
self.num_model_layers = num_model_layers
|
||||
self.num_layers_per_stage = num_layers_per_stage
|
||||
|
||||
def distribute_layers(
|
||||
self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
|
||||
) -> List[int]:
|
||||
"""Divide layers into stages"""
|
||||
num_stages = self.num_stages if num_stages is None else num_stages
|
||||
num_model_chunks = (
|
||||
(self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
|
||||
)
|
||||
|
||||
if self.control_distribute_layers:
|
||||
assert num_layers == self.num_model_layers
|
||||
return self.num_layers_per_stage
|
||||
|
||||
else:
|
||||
quotient = num_layers // (num_stages * num_model_chunks)
|
||||
remainder = num_layers % (num_stages * num_model_chunks)
|
||||
|
||||
# calculate the num_layers per stage
|
||||
layers_per_stage = [quotient] * num_stages * num_model_chunks
|
||||
|
||||
# deal with the rest layers
|
||||
if remainder > 0:
|
||||
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
|
||||
for i in range(start_position, start_position + remainder):
|
||||
layers_per_stage[i] += 1
|
||||
return layers_per_stage
|
||||
|
||||
def get_stage_index(
|
||||
self,
|
||||
layers_per_stage: List[int],
|
||||
|
@ -139,9 +95,7 @@ class PipelineStageManager:
|
|||
|
||||
"""
|
||||
stage = self.stage if stage is None else stage
|
||||
num_model_chunks = (
|
||||
(self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
|
||||
)
|
||||
num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks
|
||||
num_stages = self.num_stages if num_stages is None else num_stages
|
||||
|
||||
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
|
||||
|
@ -261,3 +215,25 @@ class PipelineStageManager:
|
|||
self.model_chunk_id = model_chunk_id
|
||||
yield
|
||||
self.model_chunk_id = old_model_chunk_id
|
||||
|
||||
def distribute_layers(
|
||||
self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
|
||||
) -> List[int]:
|
||||
if self.num_layers_per_stage is not None:
|
||||
assert sum(self.num_layers_per_stage) == num_layers
|
||||
return self.num_layers_per_stage
|
||||
|
||||
num_stages = self.num_stages if num_stages is None else num_stages
|
||||
num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks
|
||||
quotient = num_layers // (num_stages * num_model_chunks)
|
||||
remainder = num_layers % (num_stages * num_model_chunks)
|
||||
|
||||
# calculate the num_layers per stage
|
||||
layers_per_stage = [quotient] * num_stages * num_model_chunks
|
||||
|
||||
# deal with the rest layers
|
||||
if remainder > 0:
|
||||
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
|
||||
for i in range(start_position, start_position + remainder):
|
||||
layers_per_stage[i] += 1
|
||||
return layers_per_stage
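A worked example of the fallback branch, with numbers chosen for illustration: for 10 layers, 4 stages and 1 model chunk, quotient = 2 and remainder = 2, start_position = 4 // 2 - 2 // 2 = 1, so the two middle stages absorb the extra layers.

```python
# expected split for num_layers=10, num_stages=4, num_model_chunks=1
layers_per_stage = [2, 2, 2, 2]
for i in range(1, 1 + 2):  # start_position=1, remainder=2
    layers_per_stage[i] += 1
assert layers_per_stage == [2, 3, 3, 2] and sum(layers_per_stage) == 10
```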
@ -0,0 +1,7 @@
|
|||
from .bnb import quantize_model
|
||||
from .bnb_config import BnbQuantizationConfig
|
||||
|
||||
__all__ = [
|
||||
"BnbQuantizationConfig",
|
||||
"quantize_model",
|
||||
]
|
|
@ -0,0 +1,321 @@
|
|||
# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .bnb_config import BnbQuantizationConfig
|
||||
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
|
||||
IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0"
|
||||
IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2"
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def quantize_model(
|
||||
model: torch.nn.Module,
|
||||
bnb_quantization_config: BnbQuantizationConfig,
|
||||
):
|
||||
"""
|
||||
This function quantizes the input model with the config passed in `bnb_quantization_config`,
then places the quantized model on the GPU.
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`):
|
||||
Input model. The model should already be loaded.
|
||||
bnb_quantization_config (`BnbQuantizationConfig`):
|
||||
The bitsandbytes quantization parameters
|
||||
|
||||
Returns:
|
||||
`torch.nn.Module`: The quantized model
|
||||
"""
|
||||
|
||||
load_in_4bit = bnb_quantization_config.load_in_4bit
|
||||
load_in_8bit = bnb_quantization_config.load_in_8bit
|
||||
|
||||
if load_in_8bit and not IS_8BIT_BNB_AVAILABLE:
|
||||
raise ImportError(
|
||||
"You have a version of `bitsandbytes` that is not compatible with 8bit quantization,"
|
||||
" make sure you have the latest version of `bitsandbytes` installed."
|
||||
)
|
||||
if load_in_4bit and not IS_4BIT_BNB_AVAILABLE:
|
||||
raise ValueError(
|
||||
"You have a version of `bitsandbytes` that is not compatible with 4bit quantization,"
|
||||
"make sure you have the latest version of `bitsandbytes` installed."
|
||||
)
|
||||
|
||||
# We keep some modules such as the lm_head in their original dtype for numerical stability reasons
|
||||
if bnb_quantization_config.skip_modules is None:
|
||||
bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)
|
||||
|
||||
modules_to_not_convert = bnb_quantization_config.skip_modules
|
||||
|
||||
# We add the modules we want to keep in full precision
|
||||
if bnb_quantization_config.keep_in_fp32_modules is None:
|
||||
bnb_quantization_config.keep_in_fp32_modules = []
|
||||
keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules
|
||||
|
||||
# compatibility with peft
|
||||
model.is_loaded_in_4bit = load_in_4bit
|
||||
model.is_loaded_in_8bit = load_in_8bit
|
||||
|
||||
# assert model_device is cuda
|
||||
model_device = next(model.parameters()).device
|
||||
|
||||
model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)
|
||||
|
||||
# convert param to the right dtype
|
||||
dtype = bnb_quantization_config.torch_dtype
|
||||
for name, param in model.state_dict().items():
|
||||
if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
|
||||
param.to(torch.float32)
|
||||
if param.dtype != torch.float32:
|
||||
name = name.replace(".weight", "").replace(".bias", "")
|
||||
param = getattr(model, name, None)
|
||||
if param is not None:
|
||||
param.to(torch.float32)
|
||||
elif torch.is_floating_point(param):
|
||||
param.to(dtype)
|
||||
if model_device.type == "cuda":
|
||||
# move everything to cpu in the first place because we can't do quantization if the weights are already on cuda
|
||||
model.cuda(torch.cuda.current_device())
|
||||
torch.cuda.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
model.to(torch.cuda.current_device())
|
||||
logger.info(
|
||||
f"The model device type is {model_device.type}. However, cuda is needed for quantization."
|
||||
"We move the model to cuda."
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
|
||||
return model
|
||||
|
||||
|
||||
def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None):
|
||||
"""
|
||||
A helper function that replaces all `torch.nn.Linear` modules with `bnb.nn.Linear8bitLt` or `bnb.nn.Linear4bit`
modules from the `bitsandbytes` library. The function runs recursively and replaces every eligible `torch.nn.Linear` module.
|
||||
|
||||
Parameters:
|
||||
model (`torch.nn.Module`):
|
||||
Input model or `torch.nn.Module` as the function is run recursively.
|
||||
modules_to_not_convert (`List[str]`):
|
||||
Names of the modules not to convert. In practice we keep the `lm_head` in full precision for
|
||||
numerical stability reasons.
|
||||
current_key_name (`List[str]`, *optional*):
|
||||
An array to track the current key of the recursion. This is used to check whether the current key (part of
|
||||
it) is not in the list of modules to not convert.
|
||||
"""
|
||||
|
||||
if modules_to_not_convert is None:
|
||||
modules_to_not_convert = []
|
||||
|
||||
model, has_been_replaced = _replace_with_bnb_layers(
|
||||
model, bnb_quantization_config, modules_to_not_convert, current_key_name
|
||||
)
|
||||
if not has_been_replaced:
|
||||
logger.warning(
|
||||
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
|
||||
" this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
|
||||
" Please double check your model architecture, or submit an issue on github if you think this is"
|
||||
" a bug."
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def _replace_with_bnb_layers(
|
||||
model,
|
||||
bnb_quantization_config,
|
||||
modules_to_not_convert=None,
|
||||
current_key_name=None,
|
||||
):
|
||||
"""
|
||||
Private method that wraps the recursion for module replacement.
|
||||
|
||||
Returns the converted model and a boolean that indicates whether the conversion has been successful or not.
|
||||
"""
|
||||
# bitsandbytes will initialize CUDA on import, so it needs to be imported lazily
|
||||
|
||||
has_been_replaced = False
|
||||
for name, module in model.named_children():
|
||||
if current_key_name is None:
|
||||
current_key_name = []
|
||||
current_key_name.append(name)
|
||||
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
|
||||
# Check if the current key is not in the `modules_to_not_convert`
|
||||
current_key_name_str = ".".join(current_key_name)
|
||||
proceed = True
|
||||
for key in modules_to_not_convert:
|
||||
if (
|
||||
(key in current_key_name_str) and (key + "." in current_key_name_str)
|
||||
) or key == current_key_name_str:
|
||||
proceed = False
|
||||
break
|
||||
if proceed:
|
||||
# Load a bnb module with empty weights and replace the `nn.Linear` module
|
||||
if bnb_quantization_config.load_in_8bit:
|
||||
bnb_module = bnb.nn.Linear8bitLt(
|
||||
module.in_features,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
has_fp16_weights=False,
|
||||
threshold=bnb_quantization_config.llm_int8_threshold,
|
||||
)
|
||||
elif bnb_quantization_config.load_in_4bit:
|
||||
bnb_module = bnb.nn.Linear4bit(
|
||||
module.in_features,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
bnb_quantization_config.bnb_4bit_compute_dtype,
|
||||
compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,
|
||||
quant_type=bnb_quantization_config.bnb_4bit_quant_type,
|
||||
)
|
||||
else:
|
||||
raise ValueError("load_in_8bit and load_in_4bit can't be both False")
|
||||
bnb_module.weight.data = module.weight.data
|
||||
bnb_module.weight.skip_zero_check = True
|
||||
if module.bias is not None:
|
||||
bnb_module.bias.data = module.bias.data
|
||||
bnb_module.bias.skip_zero_check = True
|
||||
bnb_module.requires_grad_(False)
|
||||
setattr(model, name, bnb_module)
|
||||
has_been_replaced = True
|
||||
if len(list(module.children())) > 0:
|
||||
_, _has_been_replaced = _replace_with_bnb_layers(
|
||||
module, bnb_quantization_config, modules_to_not_convert, current_key_name
|
||||
)
|
||||
has_been_replaced = has_been_replaced | _has_been_replaced
|
||||
# Remove the last key for recursion
|
||||
current_key_name.pop(-1)
|
||||
return model, has_been_replaced
|
||||
|
||||
|
||||
def get_keys_to_not_convert(model):
|
||||
r"""
|
||||
A utility function to get the keys of the modules to keep in full precision, if any. For example, for CausalLM
modules we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures,
we want to keep the tied weights of the model. The function returns a list of the keys of the modules not to
convert to int8.
|
||||
|
||||
Parameters:
|
||||
model (`torch.nn.Module`):
|
||||
Input model
|
||||
"""
|
||||
# Create a copy of the model
|
||||
# with init_empty_weights():
|
||||
# tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
|
||||
tied_model = model
|
||||
|
||||
tied_params = find_tied_parameters(tied_model)
|
||||
# For compatibility with Accelerate < 0.18
|
||||
if isinstance(tied_params, dict):
|
||||
tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
|
||||
else:
|
||||
tied_keys = sum(tied_params, [])
|
||||
has_tied_params = len(tied_keys) > 0
|
||||
|
||||
# Check if it is a base model
|
||||
is_base_model = False
|
||||
if hasattr(model, "base_model_prefix"):
|
||||
is_base_model = not hasattr(model, model.base_model_prefix)
|
||||
|
||||
# Ignore this for base models (BertModel, GPT2Model, etc.)
|
||||
if (not has_tied_params) and is_base_model:
|
||||
return []
|
||||
|
||||
# otherwise they have an attached head
|
||||
list_modules = list(model.named_children())
|
||||
list_last_module = [list_modules[-1][0]]
|
||||
|
||||
# add last module together with tied weights
|
||||
intersection = set(list_last_module) - set(tied_keys)
|
||||
list_untouched = list(set(tied_keys)) + list(intersection)
|
||||
|
||||
# remove ".weight" from the keys
|
||||
names_to_remove = [".weight", ".bias"]
|
||||
filtered_module_names = []
|
||||
for name in list_untouched:
|
||||
for name_to_remove in names_to_remove:
|
||||
if name_to_remove in name:
|
||||
name = name.replace(name_to_remove, "")
|
||||
filtered_module_names.append(name)
|
||||
|
||||
return filtered_module_names
|
||||
|
||||
|
||||
def find_tied_parameters(model: nn.Module, **kwargs):
|
||||
"""
|
||||
Find the tied parameters in a given model.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
|
||||
them.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`): The model to inspect.
|
||||
|
||||
Returns:
|
||||
List[List[str]]: A list of lists of parameter names being all tied together.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> from collections import OrderedDict
|
||||
>>> import torch.nn as nn
|
||||
|
||||
>>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
|
||||
>>> model.linear2.weight = model.linear1.weight
|
||||
>>> find_tied_parameters(model)
|
||||
[['linear1.weight', 'linear2.weight']]
|
||||
```
|
||||
"""
|
||||
# Initialize result and named_parameters before recursing.
|
||||
named_parameters = kwargs.get("named_parameters", None)
|
||||
prefix = kwargs.get("prefix", "")
|
||||
result = kwargs.get("result", {})
|
||||
|
||||
if named_parameters is None:
|
||||
named_parameters = {n: p for n, p in model.named_parameters()}
|
||||
else:
|
||||
# A tied parameter will not be in the full `named_parameters` seen above but will be in the `named_parameters`
|
||||
# of the submodule it belongs to. So while recursing we track the names that are not in the initial
|
||||
# `named_parameters`.
|
||||
for name, parameter in model.named_parameters():
|
||||
full_name = name if prefix == "" else f"{prefix}.{name}"
|
||||
if full_name not in named_parameters:
|
||||
# When we find one, it has to be one of the existing parameters.
|
||||
for new_name, new_param in named_parameters.items():
|
||||
if new_param is parameter:
|
||||
if new_name not in result:
|
||||
result[new_name] = []
|
||||
result[new_name].append(full_name)
|
||||
|
||||
# Once we have treated direct parameters, we move to the child modules.
|
||||
for name, child in model.named_children():
|
||||
child_name = name if prefix == "" else f"{prefix}.{name}"
|
||||
find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result)
|
||||
|
||||
return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in result.items()])
|
||||
|
||||
|
||||
class FindTiedParametersResult(list):
|
||||
"""
|
||||
This is a subclass of list kept for backward compatibility with Transformers. Do not rely on it being a list
subclass or on the `values` method, as both may be removed in the future.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def values(self):
|
||||
return sum([x[1:] for x in self], [])
|
|
@ -0,0 +1,113 @@
|
|||
# adapted from Hugging Face accelerate/utils/dataclasses.py
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class BnbQuantizationConfig:
|
||||
"""
|
||||
A plugin to enable BitsAndBytes 4bit and 8bit quantization
|
||||
"""
|
||||
|
||||
load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."})
|
||||
|
||||
llm_int8_threshold: float = field(
|
||||
default=6.0, metadata={"help": "value of the outlier threshold. Only relevant when load_in_8bit=True"}
|
||||
)
|
||||
|
||||
load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."})
|
||||
|
||||
bnb_4bit_quant_type: str = field(
|
||||
default="fp4",
|
||||
metadata={
|
||||
"help": "set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','np4'}."
|
||||
},
|
||||
)
|
||||
|
||||
bnb_4bit_use_double_quant: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "enable nested quantization where the quantization constants from the first quantization are quantized again."
|
||||
},
|
||||
)
|
||||
|
||||
bnb_4bit_compute_dtype: str = field(
|
||||
default="fp16",
|
||||
metadata={
|
||||
"help": "This sets the computational type which might be different than the input time. For example, inputs might be "
|
||||
"fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}."
|
||||
},
|
||||
)
|
||||
|
||||
torch_dtype: torch.dtype = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "this sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value"
|
||||
"to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model "
|
||||
},
|
||||
)
|
||||
|
||||
skip_modules: List[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "an explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`."
|
||||
},
|
||||
)
|
||||
|
||||
keep_in_fp32_modules: List[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "an explicit list of the modules that we don't quantize. We keep them in `torch.float32`."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.bnb_4bit_compute_dtype, str):
|
||||
if self.bnb_4bit_compute_dtype == "fp32":
|
||||
self.bnb_4bit_compute_dtype = torch.float32
|
||||
elif self.bnb_4bit_compute_dtype == "fp16":
|
||||
self.bnb_4bit_compute_dtype = torch.float16
|
||||
elif self.bnb_4bit_compute_dtype == "bf16":
|
||||
self.bnb_4bit_compute_dtype = torch.bfloat16
|
||||
else:
|
||||
raise ValueError(
|
||||
f"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}"
|
||||
)
|
||||
elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
|
||||
raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")
|
||||
|
||||
if self.skip_modules is not None and not isinstance(self.skip_modules, list):
|
||||
raise ValueError("skip_modules must be a list of strings")
|
||||
|
||||
if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list):
|
||||
raise ValueError("keep_in_fp_32_modules must be a list of strings")
|
||||
|
||||
if self.load_in_4bit:
|
||||
self.target_dtype = "int4"
|
||||
|
||||
if self.load_in_8bit:
|
||||
self.target_dtype = torch.int8
|
||||
|
||||
if self.load_in_4bit and self.llm_int8_threshold != 6.0:
|
||||
warnings.warn("llm_int8_threshold can only be used for model loaded in 8bit")
|
||||
|
||||
if isinstance(self.torch_dtype, str):
|
||||
if self.torch_dtype == "fp32":
|
||||
self.torch_dtype = torch.float32
|
||||
elif self.torch_dtype == "fp16":
|
||||
self.torch_dtype = torch.float16
|
||||
elif self.torch_dtype == "bf16":
|
||||
self.torch_dtype = torch.bfloat16
|
||||
else:
|
||||
raise ValueError(f"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}")
|
||||
|
||||
if self.load_in_8bit and self.torch_dtype is None:
|
||||
self.torch_dtype = torch.float16
|
||||
|
||||
if self.load_in_4bit and self.torch_dtype is None:
|
||||
self.torch_dtype = self.bnb_4bit_compute_dtype
|
||||
|
||||
if not isinstance(self.torch_dtype, torch.dtype):
|
||||
raise ValueError("torch_dtype must be a torch.dtype")
@ -38,7 +38,7 @@ from transformers import BertForMaskedLM
|
|||
import colossalai
|
||||
|
||||
# launch colossalai
|
||||
colossalai.launch_from_torch(config={})
|
||||
colossalai.launch_from_torch()
|
||||
|
||||
# create model
|
||||
config = BertConfig.from_pretrained('bert-base-uncased')
|
||||
|
@ -114,30 +114,30 @@ We will follow this roadmap to develop Shardformer:
|
|||
- [x] Unit Testing
|
||||
- [ ] Policy Implementation
|
||||
|
||||
| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
|
||||
| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: |
|
||||
| bert | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
|
||||
| t5 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| gpt2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
|
||||
| opt | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| bloom | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
|
||||
| chatglm2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
|
||||
| vit | [√] | [√] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| whisper | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
|
||||
| sam | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| blip2 | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| falcon | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
|
||||
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| mistral | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
|
||||
|:-----------:|:---------------:|:-----------------:|:-------------------:|:-------:|:-----------:|:------------------:|:---------------:|:-----------------:|:-------:|
|
||||
| bert | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
|
||||
| t5 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| gpt2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
|
||||
| opt | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| bloom | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
|
||||
| chatglm2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
|
||||
| vit | [√] | [√] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| whisper | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
|
||||
| sam | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| blip2 | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| falcon | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
|
||||
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| mistral | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
|
||||
|
||||
## 💡 API Design
|
||||
|
@ -391,6 +391,43 @@ _POLICY_LIST = {
|
|||
}
|
||||
```
|
||||
|
||||
#### How to support models that are on the Hugging Face model hub but not in the transformers library
|
||||
|
||||
There are two cases:
|
||||
|
||||
1. The modeling file is in the `transformers` library but the model weights are not. For example, "01-ai/Yi-34B" shares the LLaMA model structure, so it is covered by the existing LLaMA policy and no new policy is needed.
2. The modeling file itself is not in the `transformers` library, e.g. "THUDM/chatglm2-6b".
|
||||
|
||||
Take "THUDM/chatglm2-6b" as an example, we clearly illustrate how to support this model in the `shardformer`.
|
||||
|
||||
Unlike LLaMA, which lives in the `transformers` library, the chatglm2 model cannot be imported directly. Thus, the policy key should be the class name as a string rather than the class object itself.
|
||||
|
||||
E.g. for llama:
|
||||
```python
|
||||
policy[LlamaDecoderLayer] = ModulePolicyDescription(...)
|
||||
```
|
||||
|
||||
for chatglm2:
|
||||
```python
|
||||
policy["GLMBlock"] = ModulePolicyDescription(...)
|
||||
```
|
||||
|
||||
Then, when registering such models in the autopolicy, we should follow the format below:
|
||||
```python
|
||||
"transformers_modules.<modeling_filename>.<class_name>": PolicyLocation(
|
||||
file_name="<policy_filename>", class_name="<policy_class_name>"
|
||||
)
|
||||
```
|
||||
|
||||
As for chatglm2 model, it should be:
|
||||
```python
|
||||
"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
|
||||
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
|
||||
)
|
||||
```
|
||||
|
||||
When using such models, `AutoModel` is supported as usual. The policy will be automatically loaded by the autopolicy.
|
||||
|
||||
### Write Your Unit Testing
|
||||
|
||||
This section serves as the guideline for testing the `shardformer` module.
|
||||
|
@ -424,13 +461,13 @@ We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate
|
|||
We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.
|
||||
|
||||
In the case of using 2 GPUs, the training times are as follows.
|
||||
| N_CTX | org_model | shard_model |
|
||||
| :------: | :-----: | :-----: |
|
||||
| 256 | 11.2ms | 17.2ms |
|
||||
| 512 | 9.8ms | 19.5ms |
|
||||
| 1024 | 19.6ms | 18.9ms |
|
||||
| 2048 | 46.6ms | 30.8ms |
|
||||
| 4096 | 160.5ms | 90.4ms |
|
||||
| N_CTX | org_model | shard_model |
|
||||
|:-----:|:---------:|:-----------:|
|
||||
| 256 | 11.2ms | 17.2ms |
|
||||
| 512 | 9.8ms | 19.5ms |
|
||||
| 1024 | 19.6ms | 18.9ms |
|
||||
| 2048 | 46.6ms | 30.8ms |
|
||||
| 4096 | 160.5ms | 90.4ms |
|
||||
|
||||
|
||||
<p align="center">
|
||||
|
@ -440,13 +477,13 @@ In the case of using 2 GPUs, the training times are as follows.
|
|||
|
||||
In the case of using 4 GPUs, the training times are as follows.
|
||||
|
||||
| N_CTX | org_model | shard_model |
|
||||
| :------: | :-----: | :-----: |
|
||||
| 256 | 10.0ms | 21.1ms |
|
||||
| 512 | 11.5ms | 20.2ms |
|
||||
| 1024 | 22.1ms | 20.6ms |
|
||||
| 2048 | 46.9ms | 24.8ms |
|
||||
| 4096 | 160.4ms | 68.0ms |
|
||||
| N_CTX | org_model | shard_model |
|
||||
|:-----:|:---------:|:-----------:|
|
||||
| 256 | 10.0ms | 21.1ms |
|
||||
| 512 | 11.5ms | 20.2ms |
|
||||
| 1024 | 22.1ms | 20.6ms |
|
||||
| 2048 | 46.9ms | 24.8ms |
|
||||
| 4096 | 160.4ms | 68.0ms |
|
||||
|
||||
|
||||
|
||||
|
@ -475,10 +512,10 @@ warmup_fraction = 0.03
|
|||
|
||||
|
||||
| accuracy | f1 | loss | GPU number | model sharded |
|
||||
| :------: | :-----: | :-----: | :--------: | :---------: |
|
||||
| 0.82971 | 0.87713 | 0.23194 | 4 | True |
|
||||
| 0.83797 | 0.88006 | 0.22683 | 2 | True |
|
||||
| 0.84521 | 0.88700 | 0.21822 | 1 | False |
|
||||
|:--------:|:-------:|:-------:|:----------:|:-------------:|
|
||||
| 0.82971 | 0.87713 | 0.23194 | 4 | True |
|
||||
| 0.83797 | 0.88006 | 0.22683 | 2 | True |
|
||||
| 0.84521 | 0.88700 | 0.21822 | 1 | False |
|
||||
|
||||
|
||||
Overall, the results demonstrate that using shardformers during model training does not affect the convergence.
|
||||
|
|
|
@ -28,7 +28,7 @@ def to_device(x: Any, device: torch.device) -> Any:
|
|||
|
||||
|
||||
def train(args):
|
||||
colossalai.launch_from_torch(config={}, seed=42)
|
||||
colossalai.launch_from_torch(seed=42)
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# prepare for data and dataset
|
||||
|
|
|
@ -1,6 +1,7 @@
"""
Shardformer Benchmark
"""

import torch
import torch.distributed as dist
import transformers

@ -84,5 +85,5 @@ def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, d
# start benchmark, command:
# torchrun --standalone --nproc_per_node=2 performance_benchmark.py
if __name__ == "__main__":
    colossalai.launch_from_torch({})
    colossalai.launch_from_torch()
    bench_shardformer.run(save_path=".", print_data=dist.get_rank() == 0)

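Both hunks above reflect the same API change: `colossalai.launch_from_torch` no longer takes a config dict as its first argument. A minimal sketch of the new-style initialization (the printed message and master-rank check are illustrative only):

```python
import colossalai
from colossalai.cluster import DistCoordinator


def main():
    # new-style launch: no `config` positional argument, only keyword settings such as the seed
    colossalai.launch_from_torch(seed=42)
    coordinator = DistCoordinator()
    if coordinator.is_master():
        print(f"initialized {coordinator.world_size} ranks")


if __name__ == "__main__":
    main()
```
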
@ -1,8 +1,8 @@
from ._operation import all_to_all_comm
from .attn import AttnMaskType, ColoAttention
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule

@ -25,6 +25,9 @@ __all__ = [
    "FusedRMSNorm",
    "FusedLinear1D_Col",
    "ParallelModule",
    "PaddingEmbedding",
    "PaddingLMHead",
    "VocabParallelLMHead1D",
    "AttnMaskType",
    "ColoAttention",
    "all_to_all_comm",

@ -8,7 +8,6 @@ from colossalai.kernel.kernel_loader import (
    FlashAttentionForFloatAndCustomMaskLoader,
    FlashAttentionLoader,
    FlashAttentionWithCustomMaskLoader,
    FlashAttentionWithPaddingMaskLoader,
    KernelLoader,
)

@ -65,15 +64,17 @@ class ColoAttention:
            half_dispatch_map = {
                None: FlashAttentionLoader(),
                AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(),
                AttnMaskType.PADDED: FlashAttentionWithPaddingMaskLoader(),
                AttnMaskType.PADDED: FlashAttentionLoader(),
                AttnMaskType.CAUSAL: FlashAttentionLoader(),
                AttnMaskType.PADDED_CAUSAL: FlashAttentionWithPaddingMaskLoader(),
                AttnMaskType.PADDED_CAUSAL: FlashAttentionLoader(),
            }
            # fp32
            float_dispatch_map = {
                None: FlashAttentionForFloatAndCustomMaskLoader(),
                AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(),
                AttnMaskType.PADDED: FlashAttentionForFloatAndCustomMaskLoader(),
                AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
                AttnMaskType.PADDED_CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
            }
            ColoAttention._kernel_dispatch_map = {
                torch.float16: half_dispatch_map,

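For orientation, the maps above implement a two-level dictionary dispatch: first on dtype, then on mask type, with each value being a kernel loader. Read together with the shrinking import list (7 lines to 6), the hunk appears to route `PADDED` and `PADDED_CAUSAL` half-precision masks to the plain `FlashAttentionLoader`, so `FlashAttentionWithPaddingMaskLoader` is no longer needed. A reduced illustration of the lookup (this is my simplification, not the actual `ColoAttention` code path):

```python
# simplified two-level dispatch: dtype -> mask type -> kernel loader
def pick_attention_loader(dispatch_map: dict, dtype, mask_type):
    per_dtype = dispatch_map[dtype]  # e.g. half_dispatch_map for torch.float16
    return per_dtype[mask_type]      # e.g. FlashAttentionLoader() for AttnMaskType.PADDED
```
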
@ -140,16 +141,22 @@ class ColoAttention:
|
|||
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
|
||||
attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
|
||||
else:
|
||||
assert q_padding_mask.shape == (
|
||||
b,
|
||||
s_q,
|
||||
), f"q_padding_mask shape {q_padding_mask.shape} should be the same. ({shape_4d})"
|
||||
max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
|
||||
if kv_padding_mask is None:
|
||||
# self attention
|
||||
kv_padding_mask = q_padding_mask
|
||||
assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == (
|
||||
max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices
|
||||
else:
|
||||
max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
|
||||
assert kv_padding_mask.shape == (
|
||||
b,
|
||||
s_kv,
|
||||
), f"q_padding_mask shape {q_padding_mask.shape} and kv_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
|
||||
attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device)
|
||||
max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
|
||||
max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
|
||||
), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
|
||||
attention_mask = q_padding_mask[:, None, :].expand(b, s_kv, s_q).to(dtype=dtype, device=device)
|
||||
outputs.update(
|
||||
{
|
||||
"cu_seqlens_q": cu_seqlens_q,
|
||||
|
|
|
@ -21,10 +21,10 @@ from colossalai.tensor.d_tensor.api import (
|
|||
)
|
||||
|
||||
from ._operation import gather_forward_split_backward, reduce_forward
|
||||
from .parallel_module import ParallelModule
|
||||
from .parallel_module import PaddingParallelModule, ParallelModule
|
||||
from .utils import create_randomizer_with_offset
|
||||
|
||||
__all__ = ["Embedding1D", "VocabParallelEmbedding1D"]
|
||||
__all__ = ["Embedding1D", "VocabParallelEmbedding1D", "PaddingEmbedding"]
|
||||
|
||||
|
||||
class Embedding1D(ParallelModule):
|
||||
|
@ -161,7 +161,80 @@ class Embedding1D(ParallelModule):
|
|||
return output_parallel
|
||||
|
||||
|
||||
class VocabParallelEmbedding1D(ParallelModule):
|
||||
class PaddingEmbedding(PaddingParallelModule):
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int = None,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
weight: Optional[nn.Parameter] = None,
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
self.embed_args = args
|
||||
self.embed_kwargs = kwargs
|
||||
self.padding_idx = padding_idx
|
||||
if num_embeddings % make_vocab_size_divisible_by != 0:
|
||||
self.num_embeddings = (
|
||||
num_embeddings + make_vocab_size_divisible_by - (num_embeddings % make_vocab_size_divisible_by)
|
||||
)
|
||||
# create weight and bias
|
||||
if weight is None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
|
||||
super().__init__(self.num_embeddings, num_embeddings, weight)
|
||||
|
||||
if weight is None:
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
init.normal_(self.weight)
|
||||
self._fill_padding_idx_with_zero()
|
||||
|
||||
def _fill_padding_idx_with_zero(self) -> None:
|
||||
if self.padding_idx is not None:
|
||||
with torch.no_grad():
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
return F.embedding(input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> PaddingParallelModule:
|
||||
r"""
|
||||
Convert a native pytorch embedding module to a parallel module.
|
||||
"""
|
||||
LazyInitContext.materialize(module)
|
||||
# get the origin attributes
|
||||
num_embeddings = module.num_embeddings
|
||||
embedding_dim = module.embedding_dim
|
||||
padding_idx = module.padding_idx
|
||||
device = module.weight.device
|
||||
# create the parallel module
|
||||
padding_embedding = PaddingEmbedding(
|
||||
num_embeddings=num_embeddings,
|
||||
embedding_dim=embedding_dim,
|
||||
padding_idx=padding_idx,
|
||||
device=device,
|
||||
weight=module.weight,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return padding_embedding
|
||||
|
||||
|
||||
class VocabParallelEmbedding1D(PaddingParallelModule):
|
||||
r"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
Args:
|
||||
|
@ -201,10 +274,10 @@ class VocabParallelEmbedding1D(ParallelModule):
|
|||
process_group: ProcessGroup = None,
|
||||
weight: Optional[nn.Parameter] = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
self.embed_args = args
|
||||
|
@ -214,8 +287,23 @@ class VocabParallelEmbedding1D(ParallelModule):
|
|||
tensor_parallel_size = dist.get_world_size(group=process_group)
|
||||
tensor_parallel_rank = dist.get_rank(group=process_group)
|
||||
|
||||
self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
|
||||
self.num_embeddings = self.num_embeddings_per_partition
|
||||
# generate weight and bias
|
||||
if weight is None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
|
||||
# calculate new padding size
|
||||
multiple = make_vocab_size_divisible_by * tensor_parallel_size
|
||||
if num_embeddings % multiple != 0:
|
||||
self.num_embeddings = num_embeddings + multiple - (num_embeddings % multiple)
|
||||
|
||||
# resize vocabulary size
|
||||
super().__init__(self.num_embeddings, num_embeddings, weight)
|
||||
|
||||
# deal with tensor parallelism
|
||||
self.num_embeddings_per_partition = divide(self.num_embeddings, tensor_parallel_size)
|
||||
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
|
||||
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
|
||||
|
||||
|
@ -226,13 +314,6 @@ class VocabParallelEmbedding1D(ParallelModule):
|
|||
seed = torch.random.initial_seed()
|
||||
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
|
||||
|
||||
# parameter
|
||||
if weight is None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
self.weight = weight
|
||||
if not is_distributed_tensor(self.weight):
|
||||
sharded_weight = shard_rowwise(self.weight.data, process_group)
|
||||
sharded_tensor_to_existing_param(sharded_weight, self.weight)
|
||||
|
@ -243,7 +324,7 @@ class VocabParallelEmbedding1D(ParallelModule):
|
|||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
) -> PaddingParallelModule:
|
||||
r"""
|
||||
Convert a native pytorch embedding module to a parallel module.
|
||||
"""
|
||||
|
@ -303,11 +384,9 @@ class VocabParallelEmbedding1D(ParallelModule):
|
|||
# Mask the input.
|
||||
masked_input = input_.clone() - self.vocab_start_index
|
||||
masked_input[input_mask] = 0
|
||||
|
||||
output_parallel = F.embedding(
|
||||
masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs
|
||||
)
|
||||
|
||||
# Mask the output embedding.
|
||||
embedding_output = output_parallel.clone()
|
||||
embedding_output[input_mask, :] = 0.0
|
||||
|
|
|
@ -32,7 +32,7 @@ from ._operation import (
|
|||
reducescatter_forward_gather_backward,
|
||||
split_forward_gather_backward,
|
||||
)
|
||||
from .parallel_module import ParallelModule
|
||||
from .parallel_module import PaddingParallelModule, ParallelModule
|
||||
from .utils import create_randomizer_with_offset
|
||||
|
||||
__all__ = ["Linear1D_Col", "Linear1D_Row"]
|
||||
|
@ -84,8 +84,9 @@ class Linear1D_Col(ParallelModule):
|
|||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(weight=weight, bias_=bias_, **kwargs)
|
||||
|
||||
# Keep input parameters
|
||||
self.in_features = in_features
|
||||
|
@ -118,6 +119,7 @@ class Linear1D_Col(ParallelModule):
|
|||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
self.weight = weight
|
||||
|
||||
if not is_distributed_tensor(self.weight):
|
||||
sharded_weight = shard_rowwise(self.weight.data, self.process_group)
|
||||
sharded_tensor_to_existing_param(sharded_weight, self.weight)
|
||||
|
@ -140,7 +142,7 @@ class Linear1D_Col(ParallelModule):
|
|||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
|
||||
) -> ParallelModule:
|
||||
r"""
|
||||
Convert a native PyTorch linear layer to a parallelized linear layer.
|
||||
|
@ -173,7 +175,6 @@ class Linear1D_Col(ParallelModule):
|
|||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -322,7 +323,7 @@ class Linear1D_Row(ParallelModule):
|
|||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
|
||||
) -> ParallelModule:
|
||||
r"""
|
||||
Convert a native PyTorch linear layer to a parallelized linear layer.
|
||||
|
@ -356,7 +357,6 @@ class Linear1D_Row(ParallelModule):
|
|||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -439,3 +439,211 @@ class Linear1D_Row(ParallelModule):
|
|||
return output
|
||||
else:
|
||||
return output, self.bias
|
||||
|
||||
|
||||
class PaddingLMHead(PaddingParallelModule):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
):
|
||||
# Keep input parameters
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
if out_features % make_vocab_size_divisible_by != 0:
|
||||
self.out_features = (
|
||||
out_features + make_vocab_size_divisible_by - (out_features % make_vocab_size_divisible_by)
|
||||
)
|
||||
if weight is None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
|
||||
if bias:
|
||||
if bias_ is None:
|
||||
self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
|
||||
else:
|
||||
bias_.data = bias_.data.to(device=device, dtype=dtype)
|
||||
else:
|
||||
bias_ = None
|
||||
|
||||
# resize embeddings
|
||||
super().__init__(self.out_features, out_features, weight, bias_)
|
||||
|
||||
if weight is None:
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
|
||||
) -> PaddingParallelModule:
|
||||
r"""
|
||||
Convert a native PyTorch linear layer to a parallelized linear layer.
|
||||
"""
|
||||
LazyInitContext.materialize(module)
|
||||
# get the attributes
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
bias = module.bias is not None
|
||||
device = module.weight.device
|
||||
# ensure only one process group is passed
|
||||
|
||||
lm_head_linear = PaddingLMHead(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return lm_head_linear
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
output = F.linear(input, self.weight, self.bias)
|
||||
output = output[..., : self.old_num_embeddings]
|
||||
return output
|
||||
|
||||
|
||||
class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
|
||||
r"""Linear layer with column parallelism.
|
||||
|
||||
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
|
||||
its second dimension as :math:`A = [A_1, ..., A_p]`.
|
||||
|
||||
Args:
|
||||
in_features (int): size of each input sample.
|
||||
out_features (int): size of each output sample.
|
||||
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
||||
device (`torch.device`): The device of parameters, defaults to None.
|
||||
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
|
||||
gather_output (bool, optional): If true, call all-gather on output and make Y available
|
||||
to all GPUs, otherwise, every GPU will have its output
|
||||
which is :math:`Y_i = XA_i`, defaults to False
|
||||
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
|
||||
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
|
||||
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
||||
which is preserved for kernel fusion, defaults to False
|
||||
weight_initializer (`typing.Callable`):
|
||||
The initializer of weight, defaults to kaiming uniform initializer.
|
||||
bias_initializer (`typing.Callable`):
|
||||
The initializer of bias, defaults to xavier uniform initializer.
|
||||
|
||||
More details about ``initializer`` please refer to
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
**kwargs,
|
||||
):
|
||||
# create weight and bias
|
||||
if weight is None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs))
|
||||
if bias:
|
||||
if bias_ is None:
|
||||
bias_ = Parameter(torch.empty(out_features, **factory_kwargs))
|
||||
else:
|
||||
bias_ = None
|
||||
|
||||
# calculate new vocab size
|
||||
self.tensor_parallel_size = dist.get_world_size(group=process_group)
|
||||
new_out_features = out_features
|
||||
multiple = make_vocab_size_divisible_by * self.tensor_parallel_size
|
||||
if out_features % multiple != 0:
|
||||
new_out_features = out_features + multiple - (out_features % multiple)
|
||||
|
||||
super().__init__(
|
||||
in_features=in_features,
|
||||
out_features=new_out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=weight,
|
||||
bias_=bias_,
|
||||
**kwargs,
|
||||
new_num_embeddings=new_out_features,
|
||||
old_num_embeddings=out_features,
|
||||
)
|
||||
# get the length of valid embeddings
|
||||
tp_rank = dist.get_rank(process_group)
|
||||
partition_size = self.new_num_embeddings // dist.get_world_size(process_group)
|
||||
if self.old_num_embeddings >= (tp_rank + 1) * partition_size:
|
||||
self.num_valid_embeddings_local = partition_size
|
||||
elif self.old_num_embeddings >= tp_rank * partition_size:
|
||||
self.num_valid_embeddings_local = self.old_num_embeddings - tp_rank * partition_size
|
||||
else:
|
||||
self.num_valid_embeddings_local = 0
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
|
||||
) -> PaddingParallelModule:
|
||||
r"""
|
||||
Convert a native PyTorch linear layer to a parallelized linear layer.
|
||||
"""
|
||||
LazyInitContext.materialize(module)
|
||||
# get the attributes
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
bias = module.bias is not None
|
||||
device = module.weight.device
|
||||
|
||||
lm_head_linear = VocabParallelLMHead1D(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return lm_head_linear
|
||||
|
||||
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
# get forward output
|
||||
if self.skip_bias_add:
|
||||
output, bias = super().forward(input_)
|
||||
else:
|
||||
output = super().forward(input_)
|
||||
|
||||
# delete the padding of output
|
||||
if self.gather_output:
|
||||
output = output[..., : self.old_num_embeddings]
|
||||
else:
|
||||
output = output[..., : self.num_valid_embeddings_local]
|
||||
|
||||
# return
|
||||
if self.skip_bias_add:
|
||||
return output, bias
|
||||
return output
|
||||
|
|
|
@ -15,7 +15,14 @@ class DistCrossEntropy(Function):
    """

    @staticmethod
    def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup):
    def forward(
        ctx,
        vocab_logits: torch.Tensor,
        target: torch.Tensor,
        ignore_index: int,
        process_group: ProcessGroup,
        vocab_size: int,
    ):
        r"""
        Calculate the cross entropy loss before gather, the origin loss function is as follows:
        loss = -log(exp(x[class])/sum(exp(x[i]))

@ -41,15 +48,21 @@ class DistCrossEntropy(Function):
        vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)

        # mask the target in the local device
        partition_vocab_size = vocab_logits.size()[-1]
        rank = dist.get_rank(group=process_group)
        world_size = dist.get_world_size(group=process_group)
        global_vocab_size = partition_vocab_size * world_size
        if vocab_size == None:
            partition_vocab_size = vocab_logits.size()[-1]
            global_vocab_size = partition_vocab_size * world_size
        else:
            global_vocab_size = vocab_size
            partition_vocab_size = global_vocab_size // world_size

        # [down, up) => false, other device and -100 => true
        delta = (global_vocab_size + world_size - 1) // world_size
        down_threshold = rank * delta
        up_threshold = down_threshold + delta
        if up_threshold > global_vocab_size:
            up_threshold = global_vocab_size
        mask = (target < down_threshold) | (target >= up_threshold)
        masked_target = target.clone() - down_threshold
        masked_target[mask] = 0

@ -57,7 +70,8 @@ class DistCrossEntropy(Function):
        # reshape the logits and target
        # reshape the vocab_logits to [bath_size * seq_len, vocab_size]
        # reshape the labels to [bath_size * seq_len]
        logits_2d = vocab_logits.view(-1, partition_vocab_size)
        self_vocab_size = vocab_logits.size()[-1]
        logits_2d = vocab_logits.view(-1, self_vocab_size)
        masked_target_1d = masked_target.view(-1)

        # extract the x[class] and set the x[other device] to zero

@ -104,10 +118,14 @@ class DistCrossEntropy(Function):
        grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update

        grad_logits.mul_(grad_output.unsqueeze(dim=-1))
        return grad_logits, None, None, None
        return grad_logits, None, None, None, None


def cross_entropy_1d(
    vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None
    vocab_logits: torch.Tensor,
    labels: torch.Tensor,
    ignore_index: int = -100,
    process_group: ProcessGroup = None,
    vocab_size: int = None,
) -> torch.Tensor:
    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group)
    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size)

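The new `vocab_size` argument lets callers state the true (padded) global vocabulary size instead of having it inferred as `local_shard_width * world_size`. A hedged usage sketch; the helper name, shapes, and padded vocab value are placeholders, and the function assumes an already-initialized tensor-parallel group:

```python
import torch
from torch.distributed import ProcessGroup

from colossalai.shardformer.layer import cross_entropy_1d


def compute_tp_loss(logits: torch.Tensor, labels: torch.Tensor, tp_group: ProcessGroup, padded_vocab: int):
    """Cross entropy over vocab-sharded logits; call inside an initialized tensor-parallel group."""
    return cross_entropy_1d(
        logits.view(-1, logits.size(-1)),  # [batch * seq_len, vocab_size // tp_size]
        labels.view(-1),
        ignore_index=-100,
        process_group=tp_group,
        vocab_size=padded_vocab,           # padded global vocab, e.g. 32000 rounded up to 32064
    )
```
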
@ -225,7 +225,13 @@ class FusedLayerNorm(BaseLayerNorm):
                # fall back to the normal fused layernorm is not built
                ApexFusedLayerNorm = FusedLayerNormWithHook
        else:
            ApexFusedLayerNorm = FusedLayerNormWithHook
            try:
                ApexFusedLayerNorm = FusedLayerNormWithHook
            except NameError:
                warnings.warn(
                    "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
                )
                return module

        layernorm = (
            ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)

@ -275,19 +281,16 @@ class FusedRMSNorm(BaseLayerNorm):
            )

        LazyInitContext.materialize(module)
        # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
        if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
            normalized_shape = module.weight.shape[0]
            eps = module.variance_epsilon
            elementwise_affine = True
        else:
            # get the attributes of the module
            normalized_shape = module.normalized_shape
            eps = module.eps
            elementwise_affine = module.elementwise_affine

        # try to get normalized_shape, eps, elementwise_affine from the module
        normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
        eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
        elementwise_affine = getattr(module, "elementwise_affine", True)

        rmsnorm = FusedRMSNormWithHook(
            normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
        )

        rmsnorm.weight = module.weight

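The unified attribute probing above covers both Hugging Face RMSNorm variants (a 1-D `weight` plus `variance_epsilon`) and LayerNorm-like modules (`normalized_shape`, `eps`, `elementwise_affine`). A tiny standalone illustration of the same pattern; the dummy class is mine, not from transformers:

```python
import torch
import torch.nn as nn


class DummyRMSNorm(nn.Module):
    """Mimics LlamaRMSNorm: a 1-D weight and a `variance_epsilon` attribute."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps


module = DummyRMSNorm(4096)
normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])        # -> 4096
eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps  # -> 1e-6
elementwise_affine = getattr(module, "elementwise_affine", True)                      # -> True
```
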
@ -3,7 +3,7 @@
|
|||
|
||||
import itertools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -20,11 +20,15 @@ from colossalai.tensor.d_tensor import (
|
|||
is_distributed_tensor,
|
||||
sharded_tensor_to_param,
|
||||
)
|
||||
from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
|
||||
|
||||
__all__ = ["ParallelModule"]
|
||||
|
||||
|
||||
class ParallelModule(nn.Module, ABC):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
|
||||
|
@ -54,7 +58,7 @@ class ParallelModule(nn.Module, ABC):
|
|||
"""
|
||||
for name, param in self._parameters.items():
|
||||
if param is not None:
|
||||
destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars)
|
||||
destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars).data
|
||||
|
||||
for name, buf in self._buffers.items():
|
||||
if buf is not None and name not in self._non_persistent_buffers_set:
|
||||
|
@ -171,3 +175,187 @@ class ParallelModule(nn.Module, ABC):
|
|||
input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child
|
||||
if input_name not in self._modules and input_name not in local_state:
|
||||
unexpected_keys.append(key)
|
||||
|
||||
|
||||
class PaddingParallelModule(ParallelModule):
|
||||
def __init__(
|
||||
self,
|
||||
new_num_embeddings: int,
|
||||
old_num_embeddings: int,
|
||||
weight: Optional[nn.Parameter],
|
||||
bias_: Optional[nn.Parameter] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.new_num_embeddings = new_num_embeddings
|
||||
self.old_num_embeddings = old_num_embeddings
|
||||
self.weight = weight
|
||||
self.bias = bias_
|
||||
|
||||
if not (is_distributed_tensor(self.weight) or self.weight.shape[0] == self.new_num_embeddings):
|
||||
self.resize_embedding_weight()
|
||||
|
||||
if self.bias is not None and not (
|
||||
is_distributed_tensor(self.bias) or self.bias.shape[0] == self.new_num_embeddings
|
||||
):
|
||||
self.resize_embedding_bias()
|
||||
|
||||
@abstractmethod
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
|
||||
) -> "PaddingParallelModule":
|
||||
"""
|
||||
Convert a native PyTorch module to a parallelized module.
|
||||
|
||||
Args:
|
||||
module (nn.Module): the module to be converted.
|
||||
process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication.
|
||||
If this is a list, the process group at the ith index of the list will correspond to the process group
|
||||
in the ith axis of the device mesh. Defaults to None, which means the global process group.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
r"""Saves module state to `destination` dictionary, containing a state
|
||||
of the module, but not its descendants. This is called on every
|
||||
submodule in :meth:`~torch.nn.Module.state_dict`.
|
||||
|
||||
In rare cases, subclasses can achieve class-specific behavior by
|
||||
overriding this method with custom logic.
|
||||
|
||||
Args:
|
||||
destination (dict): a dict where state will be stored
|
||||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
"""
|
||||
for name, param in self._parameters.items():
|
||||
if param is not None:
|
||||
param = gather_distributed_param(param, keep_vars=keep_vars)
|
||||
if is_padded_tensor(param):
|
||||
param = to_unpadded_tensor(param)
|
||||
destination[prefix + name] = param.data
|
||||
|
||||
for name, buf in self._buffers.items():
|
||||
if buf is not None and name not in self._non_persistent_buffers_set:
|
||||
destination[prefix + name] = buf if keep_vars else buf.detach()
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
|
||||
destination[extra_state_key] = self.get_extra_state()
|
||||
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into only
|
||||
this module, but not its descendants. This is called on every submodule
|
||||
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
|
||||
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
|
||||
For state dicts without metadata, :attr:`local_metadata` is empty.
|
||||
Subclasses can achieve class-specific backward compatible loading using
|
||||
the version number at `local_metadata.get("version", None)`.
|
||||
|
||||
.. note::
|
||||
:attr:`state_dict` is not the same object as the input
|
||||
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
|
||||
it can be modified.
|
||||
|
||||
Args:
|
||||
state_dict (dict): a dict containing parameters and
|
||||
persistent buffers.
|
||||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
local_metadata (dict): a dict containing the metadata for this module.
|
||||
See
|
||||
strict (bool): whether to strictly enforce that the keys in
|
||||
:attr:`state_dict` with :attr:`prefix` match the names of
|
||||
parameters and buffers in this module
|
||||
missing_keys (list of str): if ``strict=True``, add missing keys to
|
||||
this list
|
||||
unexpected_keys (list of str): if ``strict=True``, add unexpected
|
||||
keys to this list
|
||||
error_msgs (list of str): error messages should be added to this
|
||||
list, and will be reported together in
|
||||
:meth:`~torch.nn.Module.load_state_dict`
|
||||
"""
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
|
||||
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
for name, param in local_state.items():
|
||||
key = prefix + name
|
||||
|
||||
if key in state_dict:
|
||||
input_param = state_dict[key]
|
||||
if not torch.overrides.is_tensor_like(input_param):
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"expected torch.Tensor or Tensor-like object from checkpoint but "
|
||||
"received {}".format(key, type(input_param))
|
||||
)
|
||||
continue
|
||||
|
||||
if is_padded_tensor(param):
|
||||
input_param = to_padded_tensor(input_param, param._current_length, param._padding_dim)
|
||||
|
||||
if is_distributed_tensor(param):
|
||||
# shard the input param
|
||||
device_mesh = get_device_mesh(param)
|
||||
sharding_spec = get_sharding_spec(param)
|
||||
sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
|
||||
input_param = sharded_tensor_to_param(sharded_tensor)
|
||||
elif is_customized_distributed_tensor(param):
|
||||
input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn)
|
||||
|
||||
# This is used to avoid copying uninitialized parameters into
|
||||
# non-lazy modules, since they dont have the hook to do the checks
|
||||
# in such case, it will error when accessing the .shape attribute.
|
||||
is_param_lazy = torch.nn.parameter.is_lazy(param)
|
||||
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
|
||||
if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
|
||||
input_param = input_param[0]
|
||||
|
||||
if not is_param_lazy and input_param.shape != param.shape:
|
||||
# local shape should match the one in checkpoint
|
||||
error_msgs.append(
|
||||
"size mismatch for {}: copying a param with shape {} from checkpoint, "
|
||||
"the shape in current model is {}.".format(key, input_param.shape, param.shape)
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
param.copy_(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"whose dimensions in the model are {} and "
|
||||
"whose dimensions in the checkpoint are {}, "
|
||||
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
|
||||
)
|
||||
elif strict:
|
||||
missing_keys.append(key)
|
||||
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
|
||||
if extra_state_key in state_dict:
|
||||
self.set_extra_state(state_dict[extra_state_key])
|
||||
elif strict:
|
||||
missing_keys.append(extra_state_key)
|
||||
elif strict and (extra_state_key in state_dict):
|
||||
unexpected_keys.append(extra_state_key)
|
||||
|
||||
if strict:
|
||||
for key in state_dict.keys():
|
||||
if key.startswith(prefix) and key != extra_state_key:
|
||||
input_name = key[len(prefix) :]
|
||||
input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child
|
||||
if input_name not in self._modules and input_name not in local_state:
|
||||
unexpected_keys.append(key)
|
||||
|
||||
def resize_embedding_weight(self):
|
||||
self.weight = to_padded_tensor(self.weight, self.new_num_embeddings, 0)
|
||||
|
||||
def resize_embedding_bias(self):
|
||||
self.bias = to_padded_tensor(self.bias, self.new_num_embeddings, 0)
|
||||
|
|
|
@ -1287,3 +1287,16 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||
)
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_bert_intermediate_forward():
|
||||
from transformers.models.bert.modeling_bert import BertIntermediate
|
||||
|
||||
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
|
||||
|
||||
def forward(self: BertIntermediate, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, bias = self.dense(hidden_states)
|
||||
hidden_states = JitGeLUFunction.apply(hidden_states, bias)
|
||||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
|
|
@ -129,3 +129,17 @@ def get_jit_fused_blip2_QFormer_output_forward():
|
|||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_blip2_mlp_forward():
|
||||
from transformers.models.blip_2.modeling_blip_2 import Blip2MLP
|
||||
|
||||
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
|
||||
|
||||
def forward(self: Blip2MLP, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, bias = self.fc1(hidden_states)
|
||||
hidden_states = JitGeLUFunction.apply(hidden_states, bias)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
|
|
@ -6,6 +6,7 @@ import torch.distributed as dist
|
|||
from torch.distributed import ProcessGroup
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.nn import functional as F
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
|
@ -205,12 +206,13 @@ class BloomPipelineForwards:
|
|||
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
|
||||
|
||||
# causal_mask is constructed every stage and its input is passed through different stages
|
||||
causal_mask = self._prepare_attn_mask(
|
||||
causal_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
input_shape=(batch_size, seq_length),
|
||||
inputs_embeds=hidden_states,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
causal_mask = causal_mask.bool()
|
||||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
|
@ -227,21 +229,15 @@ class BloomPipelineForwards:
|
|||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
outputs = self._gradient_checkpointing_func(
|
||||
block.__call__,
|
||||
hidden_states,
|
||||
alibi,
|
||||
causal_mask,
|
||||
layer_past,
|
||||
head_mask[i],
|
||||
use_cache,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
|
@ -1002,11 +998,13 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||
|
||||
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
|
||||
|
||||
causal_mask = self._prepare_attn_mask(
|
||||
causal_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
input_shape=(batch_size, seq_length),
|
||||
inputs_embeds=hidden_states,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
causal_mask = causal_mask.bool()
|
||||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
hidden_states = split_forward_gather_backward(
|
||||
|
@ -1018,21 +1016,15 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
outputs = self._gradient_checkpointing_func(
|
||||
block.__call__,
|
||||
hidden_states,
|
||||
alibi,
|
||||
causal_mask,
|
||||
layer_past,
|
||||
head_mask[i],
|
||||
use_cache,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
|
|
|
@ -12,7 +12,6 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
|||
from colossalai.shardformer import ShardConfig
|
||||
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
|
||||
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
|
||||
|
||||
|
||||
def get_flash_core_attention_forward():
|
||||
|
@ -31,7 +30,12 @@ def get_flash_core_attention_forward():
|
|||
device=query_layer.device,
|
||||
)
|
||||
temp_mask = (
|
||||
torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device)
|
||||
torch.ones(
|
||||
query_layer.shape[2],
|
||||
key_layer.shape[2],
|
||||
dtype=torch.bool,
|
||||
device=query_layer.device,
|
||||
)
|
||||
.tril(diagonal=0)
|
||||
.expand(query_layer.shape[0], 1, -1, -1)
|
||||
)
|
||||
|
@ -49,6 +53,7 @@ def get_flash_core_attention_forward():
|
|||
attention_mask=attn_bias,
|
||||
attention_mask_type=attention_mask_type,
|
||||
dropout_p=dropout_p,
|
||||
scale=1.0 / self.norm_factor,
|
||||
)
|
||||
context_layer = context_layer.permute(2, 0, 1, 3)
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
||||
|
@ -115,7 +120,7 @@ class ChatGLMPipelineForwards:
|
|||
|
||||
@staticmethod
|
||||
def chatglm_model_forward(
|
||||
self: ChatGLMModel,
|
||||
self: "ChatGLMModel",
|
||||
input_ids,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
|
@ -194,7 +199,9 @@ class ChatGLMPipelineForwards:
|
|||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
for idx in range(start_idx, end_idx):
|
||||
layer = self.encoder._get_layer(idx)
|
||||
|
@ -224,7 +231,9 @@ class ChatGLMPipelineForwards:
|
|||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
@ -254,7 +263,7 @@ class ChatGLMPipelineForwards:
|
|||
|
||||
@staticmethod
|
||||
def chatglm_for_conditional_generation_forward(
|
||||
self: ChatGLMForConditionalGeneration,
|
||||
self: "ChatGLMForConditionalGeneration",
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
|
|
|
@ -1,9 +1,16 @@
|
|||
import math
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
AttentionMaskConverter,
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
|
@ -99,11 +106,17 @@ def get_tp_falcon_decoder_layer_forward():
|
|||
hidden_states: torch.Tensor,
|
||||
alibi: Optional[torch.Tensor],
|
||||
attention_mask: torch.Tensor,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
residual = hidden_states
|
||||
|
||||
if self.config.new_decoder_architecture:
|
||||
|
@ -117,10 +130,12 @@ def get_tp_falcon_decoder_layer_forward():
|
|||
attention_layernorm_out,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
alibi=alibi,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attention_output = attn_outputs[0]
|
||||
|
@ -154,87 +169,6 @@ def get_tp_falcon_decoder_layer_forward():
|
|||
return forward
|
||||
|
||||
|
||||
def get_falcon_flash_attention_forward():
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention as me_attention
|
||||
except:
|
||||
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
|
||||
from transformers.models.falcon.modeling_falcon import FalconAttention
|
||||
|
||||
def forward(
|
||||
self: FalconAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: Optional[torch.Tensor],
|
||||
attention_mask: torch.Tensor,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
||||
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
|
||||
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
||||
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
||||
|
||||
batch_size, query_length, _, _ = query_layer.shape
|
||||
|
||||
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
|
||||
key_layer = key_layer.transpose(1, 2).reshape(
|
||||
batch_size * num_kv_heads,
|
||||
query_length,
|
||||
self.head_dim,
|
||||
)
|
||||
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
|
||||
|
||||
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
|
||||
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
# concatenate along seq_length dimension:
|
||||
# - key: [batch_size * self.num_heads, kv_length, head_dim]
|
||||
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
||||
key_layer = torch.cat((past_key, key_layer), dim=1)
|
||||
value_layer = torch.cat((past_value, value_layer), dim=1)
|
||||
|
||||
_, kv_length, _ = key_layer.shape
|
||||
if use_cache:
|
||||
present = (key_layer, value_layer)
|
||||
else:
|
||||
present = None
|
||||
|
||||
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
|
||||
|
||||
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).contiguous()
|
||||
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous()
|
||||
value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
if alibi is not None:
|
||||
attention_mask_float = (
|
||||
attention_mask_float + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
|
||||
)
|
||||
|
||||
batch_size, src_len = query_layer_.size()[0], query_layer_.size()[1]
|
||||
tgt_len = key_layer_.size()[1]
|
||||
attention_mask_float = attention_mask_float.expand(batch_size, self.num_heads, src_len, tgt_len).contiguous()
|
||||
context_layer = me_attention(
|
||||
query_layer_,
|
||||
key_layer_,
|
||||
value_layer_,
|
||||
attn_bias=attention_mask_float,
|
||||
scale=self.inv_norm_factor,
|
||||
p=self.attention_dropout.p,
|
||||
)
|
||||
batch_size, seq_length, _, _ = context_layer.shape
|
||||
context_layer = context_layer.reshape(batch_size, seq_length, -1)
|
||||
|
||||
output_tensor = self.dense(context_layer)
|
||||
|
||||
return output_tensor, present
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
class FalconPipelineForwards:
|
||||
"""
|
||||
This class serves as a micro library for falcon pipeline forwards.
|
||||
|
@ -246,6 +180,7 @@ class FalconPipelineForwards:
|
|||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
|
@ -274,17 +209,6 @@ class FalconPipelineForwards:
|
|||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if past_key_values is None:
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_key_values = self._convert_to_rw_cache(past_key_values)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape batch_size x num_heads x N x N
|
||||
# head_mask has shape n_layer x batch x num_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
# case: First stage of training
|
||||
if stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
|
@ -295,16 +219,22 @@ class FalconPipelineForwards:
|
|||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
|
||||
if past_key_values is None:
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
@ -312,22 +242,80 @@ class FalconPipelineForwards:
|
|||
# Compute alibi tensor: check build_alibi_tensor documentation
|
||||
past_key_values_length = 0
|
||||
if past_key_values[0] is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
|
||||
else:
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
past_key_values_length = past_key_values[0][0].shape[-2]
|
||||
|
||||
if self.use_alibi:
|
||||
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
|
||||
mask = (
|
||||
torch.ones(
|
||||
(batch_size, seq_length + past_key_values_length), device=inputs_embeds.device, dtype=torch.long
|
||||
)
|
||||
if attention_mask is None
|
||||
else attention_mask
|
||||
)
|
||||
alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
|
||||
else:
|
||||
alibi = None
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
|
||||
causal_mask = self._prepare_attn_mask(
|
||||
attention_mask,
|
||||
input_shape=(batch_size, seq_length),
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._use_sdpa and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
if alibi is None:
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
elif head_mask is None:
|
||||
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
|
||||
|
||||
attention_mask_2d = attention_mask
|
||||
# We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
# We take care to integrate alibi bias in the attention_mask here.
|
||||
if attention_mask_2d is None:
|
||||
attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
|
||||
else:
|
||||
attention_mask = torch.masked_fill(
|
||||
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
|
||||
attention_mask < -1,
|
||||
torch.finfo(alibi.dtype).min,
|
||||
)
|
||||
|
||||
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
|
||||
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
if seq_length > 1:
|
||||
attention_mask = AttentionMaskConverter._unmask_unattended(
|
||||
attention_mask, attention_mask_2d, unmasked_value=0.0
|
||||
)
|
||||
else:
|
||||
# PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape batch_size x num_heads x N x N
|
||||
# head_mask has shape n_layer x batch x num_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
for i, (block, layer_past) in enumerate(
|
||||
|
@ -337,31 +325,23 @@ class FalconPipelineForwards:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    alibi,
                    causal_mask,
                    attention_mask,
                    position_ids,
                    head_mask[i],
                    layer_past,
                    use_cache,
                    output_attentions,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=causal_mask,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
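Across these hunks the hand-rolled `create_custom_forward` closure plus `torch.utils.checkpoint.checkpoint` is replaced by transformers' `self._gradient_checkpointing_func(block.__call__, ...)` helper; both ultimately route through PyTorch activation checkpointing. A minimal standalone sketch of that pattern (the `Block` module and the shapes are invented for illustration):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.ff(x)

block = Block()
x = torch.randn(2, 16, 64, requires_grad=True)

# recompute the block's activations during backward instead of storing them
y = checkpoint(block.__call__, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 16, 64])
```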
@ -382,9 +362,6 @@ class FalconPipelineForwards:
|
|||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if presents is not None:
|
||||
presents = self._convert_cache_to_standard_format(presents, batch_size)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
|
|
|
@ -26,7 +26,6 @@ from colossalai.shardformer.layer._operation import gather_forward_split_backwar
from colossalai.shardformer.shard import ShardConfig

from ..layer import cross_entropy_1d
from ..layer._operation import gather_forward_split_backward

logger = logging.get_logger(__name__)

@ -178,11 +177,9 @@ class GPT2PipelineForwards:
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if stage_manager.is_first_stage():
            if position_ids is not None:
                position_ids = position_ids.view(-1, input_shape[-1])
            else:
            if position_ids is None:
                position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
                position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
                position_ids = position_ids.unsqueeze(0)

            if inputs_embeds is None:
                inputs_embeds = self.wte(input_ids)
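The position-id handling above (and the analogous GPTJ and Llama hunks later in this diff) builds default positions with `torch.arange`, offset by the length of any cached prefix. A tiny sketch of that behaviour in isolation (function name invented for illustration):

```python
import torch

def default_position_ids(seq_length: int, past_key_values_length: int = 0, device="cpu"):
    # positions continue from the cached prefix during incremental decoding
    ids = torch.arange(
        past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
    )
    return ids.unsqueeze(0)  # (1, seq_length), broadcast over the batch

print(default_position_ids(4))                             # tensor([[0, 1, 2, 3]])
print(default_position_ids(1, past_key_values_length=4))   # tensor([[4]])
```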
@ -240,22 +237,16 @@ class GPT2PipelineForwards:
|
|||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
outputs = self._gradient_checkpointing_func(
|
||||
block.__call__,
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
use_cache,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
|
@ -397,13 +388,11 @@ class GPT2PipelineForwards:
|
|||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
)
|
||||
else:
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not shard_config.parallel_output:
|
||||
lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
@ -1301,12 +1290,12 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
||||
shift_labels = shift_labels.view(-1)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
)
|
||||
|
||||
if not shard_config.parallel_output:
|
||||
lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
@ -1321,3 +1310,18 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
        )

    return forward


def get_jit_fused_gpt2_mlp_forward():
    from transformers.models.gpt2.modeling_gpt2 import GPT2MLP

    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction

    def forward(self: GPT2MLP, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        hidden_states, bias = self.c_fc(hidden_states)
        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states

    return forward
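The new `get_jit_fused_gpt2_mlp_forward` relies on ColossalAI's `JitGeLUFunction` to fuse the bias add and the GeLU activation. A rough sketch of the underlying idea using a plain `torch.jit.script` function with the tanh GeLU approximation (an illustration only; the ColossalAI kernel also fuses the backward pass):

```python
import torch

@torch.jit.script
def bias_gelu(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # tanh-approximated GeLU applied to (y + bias), compiled into one scripted op
    x = y + bias
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))

x = torch.randn(2, 4, 16)
b = torch.randn(16)
print(bias_gelu(b, x).shape)  # torch.Size([2, 4, 16])
```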
@ -148,11 +148,9 @@ class GPTJPipelineForwards:
|
|||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
# position id to be assigned not just for the first stage for attn input
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.view(-1, seq_length)
|
||||
else:
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
if stage_manager.is_first_stage():
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
|
@ -201,21 +199,15 @@ class GPTJPipelineForwards:
|
|||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
outputs = self._gradient_checkpointing_func(
|
||||
block.__call__,
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
head_mask[i],
|
||||
use_cache,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
|
@ -627,7 +619,9 @@ def get_gptj_flash_attention_forward():
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
            # Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
            # Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
            present = (key.to(hidden_states.dtype), value)
        else:
            present = None

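This hunk keeps GPT-J on the legacy tuple-style KV cache: new key/value tensors are concatenated onto the cached ones along the sequence dimension. A minimal sketch of that cache update (function name and shapes are made up for illustration):

```python
import torch

def update_kv_cache(past, key, value):
    # legacy tuple-style cache: concatenate new K/V along the sequence dimension
    if past is not None:
        past_key, past_value = past
        key = torch.cat((past_key, key), dim=-2)
        value = torch.cat((past_value, value), dim=-2)
    return key, value

k = torch.randn(1, 8, 1, 64)  # (batch, heads, new_tokens, head_dim)
v = torch.randn(1, 8, 1, 64)
cache = None
for _ in range(3):
    k_all, v_all = update_kv_cache(cache, k, v)
    cache = (k_all, v_all)
print(cache[0].shape)  # torch.Size([1, 8, 3, 64])
```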
@ -7,6 +7,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
@ -16,6 +17,8 @@ from transformers.models.llama.modeling_llama import (
    LlamaForCausalLM,
    LlamaForSequenceClassification,
    LlamaModel,
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
    apply_rotary_pos_emb,
    repeat_kv,
)
@ -31,13 +34,6 @@ from colossalai.shardformer.shard import ShardConfig

from ..layer import ColoAttention, cross_entropy_1d

try:
    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask

    LATEST_VERSION = True
except ImportError:
    LATEST_VERSION = False


class LlamaPipelineForwards:
    """
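The try/except probe for `_prepare_4d_causal_attention_mask` is dropped because the helper is now imported unconditionally. Conceptually it turns a 2D padding mask into an additive 4D causal mask; a simplified sketch of that idea (not the transformers implementation, which also handles past-key-values offsets and sliding windows):

```python
import torch

def simple_causal_mask(padding_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # padding_mask: (batch, seq_len) with 1 for real tokens, 0 for padding
    bsz, seq_len = padding_mask.shape
    min_val = torch.finfo(dtype).min
    causal = torch.triu(torch.full((seq_len, seq_len), min_val, dtype=dtype), diagonal=1)
    mask = causal[None, None, :, :].expand(bsz, 1, seq_len, seq_len).clone()
    mask = mask.masked_fill(padding_mask[:, None, None, :] == 0, min_val)
    return mask  # (batch, 1, seq_len, seq_len), 0 where attending is allowed

pad = torch.tensor([[1, 1, 1, 0]])
print(simple_causal_mask(pad).shape)  # torch.Size([1, 1, 4, 4])
```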
@ -75,13 +71,13 @@ class LlamaPipelineForwards:
|
|||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
batch_size, seq_length = input_ids.shape[:2]
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
batch_size, seq_length, _ = inputs_embeds.shape[:2]
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
@ -111,11 +107,12 @@ class LlamaPipelineForwards:
|
|||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
|
@ -123,20 +120,32 @@ class LlamaPipelineForwards:
|
|||
# in this case, attention_mask is a dict rather than a tensor
|
||||
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
|
||||
mask_shape,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
|
||||
)
|
||||
if LATEST_VERSION:
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._use_sdpa and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
else:
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
hidden_states,
|
||||
past_key_values_length,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
@ -149,7 +158,7 @@ class LlamaPipelineForwards:
|
|||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
next_decoder_cache = None
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
num_ckpt_layers = 0
|
||||
|
@ -159,8 +168,10 @@ class LlamaPipelineForwards:
|
|||
if shard_config.gradient_checkpoint_config is not None:
|
||||
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
|
||||
stage=stage_manager.stage,
|
||||
num_stages=stage_manager.num_stages,
|
||||
num_layers=end_idx - start_idx,
|
||||
model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0,
|
||||
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
)
|
||||
assert num_ckpt_layers <= end_idx - start_idx
|
||||
|
||||
|
@ -168,30 +179,22 @@ class LlamaPipelineForwards:
|
|||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if idx - start_idx < num_ckpt_layers:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
@ -199,7 +202,7 @@ class LlamaPipelineForwards:
|
|||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
|
@ -212,7 +215,16 @@ class LlamaPipelineForwards:
|
|||
next_cache = next_decoder_cache if use_cache else None
|
||||
if stage_manager.is_last_stage():
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
|
@ -316,7 +328,10 @@ class LlamaPipelineForwards:
|
|||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
)
|
||||
else:
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
|
@ -455,23 +470,25 @@ class LlamaPipelineForwards:
def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
    from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

    llama_version = 2
    try:
        from transformers.models.llama.modeling_llama import repeat_kv
    except:
        warnings.warn("using llamav1, llamav1 hasn't repeat_kv function")
        llama_version = 1

    def forward(
        self: LlamaAttention,
        hidden_states: torch.Tensor,
        attention_mask: Optional[dict] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        if sp_mode in ["split_gather", "ring"]:
@ -495,21 +512,23 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        if llama_version == 2:
            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
        attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
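`repeat_kv`, imported above for grouped-query attention, expands each key/value head so it lines up with the query heads before the fused attention call. A self-contained sketch equivalent in spirit to the transformers helper:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # expand (batch, num_kv_heads, seq_len, head_dim) to
    # (batch, num_kv_heads * n_rep, seq_len, head_dim) so K/V match the query heads
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

k = torch.randn(2, 4, 10, 64)  # 4 KV heads
print(repeat_kv(k, 8).shape)   # torch.Size([2, 32, 10, 64]), matching 32 query heads
```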
@ -570,7 +589,10 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
|
@ -584,7 +606,11 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|||
# in this case, attention_mask is a dict rather than a tensor
|
||||
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
|
||||
mask_shape,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
@ -735,11 +761,13 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)

            new_vocab_size = logits.shape[-1]
            shift_logits = shift_logits.view(-1, new_vocab_size)
            loss = cross_entropy_1d(
                shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
                shift_logits,
                shift_labels,
                process_group=shard_config.tensor_parallel_process_group,
                vocab_size=self.lm_head.out_features,
            )

        if not return_dict:
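`cross_entropy_1d` computes the usual shift-by-one language-modeling loss, but over logits whose vocabulary dimension is sharded across the tensor-parallel group (hence the `process_group` and `vocab_size` arguments). For comparison, a sketch of the single-device version of the same loss:

```python
import torch
from torch.nn import CrossEntropyLoss

def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # shift so that tokens < n predict token n, then flatten for CrossEntropyLoss
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = CrossEntropyLoss()
    return loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

logits = torch.randn(2, 8, 32000)          # (batch, seq, vocab)
labels = torch.randint(0, 32000, (2, 8))
print(causal_lm_loss(logits, labels))
```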
@ -913,7 +941,10 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
|
|||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
|
@ -929,10 +960,12 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
|
|||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
|
|
|
@ -1,70 +1,608 @@
|
|||
from typing import Optional, Tuple
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import MistralForCausalLM, MistralModel
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import ColoAttention
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_mistral_flash_attention_forward():
|
||||
class MistralForwards:
|
||||
@staticmethod
|
||||
def mistral_model_forward(
|
||||
self: MistralModel,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
if use_cache:
|
||||
logger.warning_once("use_cache=True is not supported for Mistral models at the moment.")
|
||||
use_cache = False
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
past_key_values_length = 0
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
|
||||
if shard_config.enable_flash_attention:
|
||||
# in this case, attention_mask is a dict rather than a tensor
|
||||
mask_shape = (batch_size, 1, seq_length, seq_length)
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
hidden_states,
|
||||
past_key_values_length,
|
||||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
num_ckpt_layers = 0
|
||||
if self.gradient_checkpointing and self.training:
|
||||
num_ckpt_layers = end_idx - start_idx
|
||||
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
|
||||
if shard_config.gradient_checkpoint_config is not None:
|
||||
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
|
||||
stage=stage_manager.stage,
|
||||
num_stages=stage_manager.num_stages,
|
||||
num_layers=end_idx - start_idx,
|
||||
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
)
|
||||
assert num_ckpt_layers <= end_idx - start_idx
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if idx - start_idx < num_ckpt_layers:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = None
|
||||
if stage_manager.is_last_stage():
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
else:
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def mistral_for_causal_lm_forward(
|
||||
self: MistralForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, MistralForCausalLM
|
||||
|
||||
>>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = MistralForwards.mistral_model_forward(
|
||||
self.model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
|
||||
past_key_values = None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
else:
|
||||
hidden_states = outputs.get("hidden_states")
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def mistral_for_sequence_classification_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = MistralForwards.mistral_model_forward(
|
||||
self.model,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size = input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
else:
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = transformer_outputs[0]
|
||||
logits = self.score(hidden_states)
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
|
||||
logits.device
|
||||
)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
else:
|
||||
hidden_states = transformer_outputs.get("hidden_states")
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
return SequenceClassifierOutputWithPast(
|
||||
loss=loss,
|
||||
logits=pooled_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig):
|
||||
logger = logging.get_logger(__name__)
|
||||
assert shard_config.enable_flash_attention, "Flash Attention is not enabled."
|
||||
|
||||
def forward(
|
||||
self: MistralModel,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
past_key_values_length = 0
|
||||
|
||||
if use_cache:
|
||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||
if use_legacy_cache:
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
if shard_config.enable_flash_attention:
|
||||
# in this case, attention_mask is a dict rather than a tensor
|
||||
mask_shape = (batch_size, 1, seq_length, seq_length)
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
inputs_embeds.dtype,
|
||||
inputs_embeds.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_mistral_flash_attention_forward(shard_config: ShardConfig):
|
||||
from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
|
||||
|
||||
def forward(
|
||||
self: MistralAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = (
|
||||
self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
)
|
||||
value_states = (
|
||||
self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
)
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
        me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
        query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
        key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
        value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape)
        assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
        attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)

        flash_attention_mask = None
        attn_mask_type = AttnMaskType.causal
        if attention_mask != None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
            attn_mask_type = AttnMaskType.paddedcausal

        attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
        attn_output = attention(
            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

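Both the old per-module `ColoAttention(embed_dim=..., num_heads=...)` object and the new static `ColoAttention.attention(...)` call above dispatch to fused attention kernels. As a rough stand-in for what such a fused call computes, PyTorch's built-in `scaled_dot_product_attention` can be used directly (the shapes below are invented for illustration):

```python
import torch
import torch.nn.functional as F

bsz, heads, q_len, head_dim = 2, 8, 16, 64
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, q_len, head_dim)
v = torch.randn(bsz, heads, q_len, head_dim)

# causal attention via PyTorch's fused kernel
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out = out.transpose(1, 2).reshape(bsz, q_len, heads * head_dim)  # back to (batch, seq, hidden)
print(out.shape)  # torch.Size([2, 16, 512])
```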
@ -3,6 +3,7 @@ from typing import List, Optional, Tuple, Union
|
|||
|
||||
import torch
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
|
@ -42,7 +43,7 @@ def _get_attention_mask(
|
|||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
attention_mask = self.decoder._prepare_decoder_attention_mask(
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
hidden_states,
|
||||
|
@ -112,7 +113,7 @@ class OPTPipelineForwards:
|
|||
inputs_embeds = decoder.project_in(inputs_embeds)
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
inputs_embeds.dtype
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
if hidden_states is None:
|
||||
raise ValueError("hidden_states shouldn't be None for intermediate stages.")
|
||||
|
@ -125,12 +126,25 @@ class OPTPipelineForwards:
|
|||
# required mask seq length can be calculated via length of past
|
||||
mask_seq_length = past_key_values_length + seq_length
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
|
||||
elif attention_mask.shape[1] != mask_seq_length:
|
||||
raise ValueError(
|
||||
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
|
||||
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
|
||||
if self.decoder._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
attention_mask = (
|
||||
torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
|
||||
if attention_mask is None
|
||||
else attention_mask
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
|
||||
elif attention_mask.shape[1] != mask_seq_length:
|
||||
raise ValueError(
|
||||
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
|
||||
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
|
||||
)
|
||||
causal_attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, hidden_states, past_key_values_length
|
||||
)
|
||||
|
||||
if stage_manager.is_first_stage():
|
||||
|
@ -205,20 +219,14 @@ class OPTPipelineForwards:
|
|||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if decoder.gradient_checkpointing and decoder.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
None,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
|
|