From b397104438581ecd72c9fcc0ab25121a243fb90a Mon Sep 17 00:00:00 2001
From: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Date: Thu, 7 Dec 2023 14:02:03 +0800
Subject: [PATCH] [Colossal-Llama-2] Add finetuning Colossal-Llama-2 example
(#4878)
* Add finetuning Colossal-Llama-2 example
* Add finetuning Colossal-Llama-2 example 2
* Add finetuning Colossal-Llama-2 example and support NEFTuning
* Add inference example and refine neftune
* Modify readme file
* update the imports
---------
Co-authored-by: Xu Yuanchen
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
---
applications/Colossal-LLaMA-2/README.md | 90 +++-
.../colossal_llama2/dataset/conversation.py | 96 +++++
.../dataset/spliced_and_tokenized_dataset.py | 135 +++++-
.../colossal_llama2/utils/neftune_patch.py | 69 +++
.../Colossal-LLaMA-2/inference_example.py | 57 +++
.../prepare_pretrain_dataset.py | 12 +-
.../Colossal-LLaMA-2/prepare_sft_dataset.py | 147 +++++++
.../Colossal-LLaMA-2/train_sft.example.sh | 46 ++
applications/Colossal-LLaMA-2/train_sft.py | 403 ++++++++++++++++++
9 files changed, 1036 insertions(+), 19 deletions(-)
create mode 100644 applications/Colossal-LLaMA-2/colossal_llama2/dataset/conversation.py
create mode 100644 applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py
create mode 100644 applications/Colossal-LLaMA-2/inference_example.py
create mode 100644 applications/Colossal-LLaMA-2/prepare_sft_dataset.py
create mode 100755 applications/Colossal-LLaMA-2/train_sft.example.sh
create mode 100644 applications/Colossal-LLaMA-2/train_sft.py
diff --git a/applications/Colossal-LLaMA-2/README.md b/applications/Colossal-LLaMA-2/README.md
index 1d44c5e76..03793bff4 100644
--- a/applications/Colossal-LLaMA-2/README.md
+++ b/applications/Colossal-LLaMA-2/README.md
@@ -11,7 +11,10 @@
- [Performance Evaluation](#performance-evaluation)
- [Examples](#examples)
- [Training Logs](#training-logs)
- - [Import from Transformers (Inference)](#import-from-transformers-inference)
+ - [Inference](#inference)
+ - [Import from HuggingFace](#import-from-huggingface)
+ - [Import from Modelscope](#import-from-modelscope)
+ - [Quick Start](#quick-start)
- [Usage](#usage)
- [Install](#install)
- [0. Pre-requisite](#0-pre-requisite)
@@ -21,8 +24,14 @@
- [1. Init Tokenizer Preparation](#1-init-tokenizer-preparation)
- [2. Init Model Preparation](#2-init-model-preparation)
- [3. Data Preparation](#3-data-preparation)
+ - [3.1 Data for Pretraining](#31-data-for-pretraining)
+ - [3.2 Data for Supervised Fine-tuning](#32-data-for-supervised-fine-tuning)
- [4. Command Line Arguments for Training](#4-command-line-arguments-for-training)
+ - [4.1 Arguments for Pretraining](#41-arguments-for-pretraining)
+ - [4.2 Arguments for Supervised Fine-tuning](#42-arguments-for-supervised-fine-tuning)
- [5. Running Command](#5-running-command)
+ - [5.1 Command for Pretraining](#51-command-for-pretraining)
+ - [5.2 Command for Supervised Fine-tuning](#52-command-for-supervised-fine-tuning)
- [Technical Insights](#technical-insights)
- [Data](#data)
- [Tokenizer](#tokenizer)
@@ -117,7 +126,8 @@ We also recorded the training logs for the experiment
-### Import from Transformers (Inference)
+### Inference
+#### Import from HuggingFace
To load Colossal-LLaMA-2-7B-base model using Transformers, use the following code:
```Python
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -135,14 +145,15 @@ pred = model.generate(**inputs,
print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)[len(input):])
```
+#### Import from Modelscope
You can also load our model using modelscope, use the following code:
```Python
from modelscope import AutoModelForCausalLM, AutoTokenizer, snapshot_download
model_dir = snapshot_download('colossalai/Colossal-LLaMA-2-7b-base', revision='v1.0.1')
tokenizer = AutoTokenizer.from_pretrained(model_dir, device_map="auto", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True).eval()
-generation_kwargs = {"max_new_tokens": 256,
- "top_p": 0.95,
+generation_kwargs = {"max_new_tokens": 256,
+ "top_p": 0.95,
"temperature": 0.3
}
input = '离离原上草,'
@@ -153,6 +164,30 @@ print(tokenizer.decode(output.cpu()[0], skip_special_tokens=True)[len(input):])
```
You can download model weights from [🤗HuggingFace](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base) or [👾Modelscope](https://modelscope.cn/models/colossalai/Colossal-LLaMA-2-7b-base/summary).
+#### Quick Start
+You can run [`inference_example.py`](inference_example.py) to quickly start the inference of our base model by loading model weights from HF.
+
+Command to run the script:
+```bash
+python inference_example.py \
+ --model_path "" \
+ --device "cuda:0" \
+ --max_new_tokens 512 \
+ --do_sample True \
+ --temperature 0.3 \
+ --top_k 50 \
+ --top_p 0.95 \
+ --input_txt "YOUR_PROMPT_OR_QUESTION"
+```
+Here is details about CLI arguments:
+* Model path: `--model_path`. HF repo name or local path of the model.
+* Device: `--device`. Set the device.
+* Max new tokens: `--max_new_tokens`. Set maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
+* Do sample: `--do_sample`. Set whether or not to use sampling.
+* Temperature: `--temperature`. Set temperature value.
+* Top_k: `--top_k`. Set top_k value for top-k-filtering.
+* Top_p: `--top_p`. Set top_p value for generation.
+* Input_txt: `--input_txt`. The prompt string input to the model.
## Usage
### Install
@@ -218,6 +253,8 @@ Here is details about CLI arguments:
❗️**Important**: Once you initialize the new model checkpoint, copy your new tokenizer files (`special_tokens_map.json`, `tokenizer.model` and `tokenizer_config.json`) to your new model folder.
#### 3. Data Preparation
+
+##### 3.1 Data for Pretraining
Raw data should be formatted as `jsonl` format. Each data point should have the following fields:
* `source` (str, compulsory): This part is ignored when calculating loss. Default can be empty.
* `target` (str, compulsory): Loss will be calculated.
@@ -250,7 +287,31 @@ Here is details about CLI arguments:
* Max length: `max_length`. Max length of spliced samples. Default value is 4096.
* Number of bins for each category: `num_spliced_dataset_bins`. Number of bins for each category, used for bucket-based training.
+##### 3.2 Data for Supervised Fine-tuning
+We prepare data for supervised fine-tuning in a similar way. The main difference lies in the data format. Each data point should have the following field:
+* `messages` (list, compulsory): This part consists of a conversation between a human and assistant. The length of `messages` can vary and only content from `assistant` is used for calculating loss.
+
+Examples:
+```JSON
+{"messages": [{"from": "human", "content": "What are the three primary colors?"}, {"from": "assistant", "content": "The three primary colors are red, blue, and yellow."}]}
+{"messages": [{"from": "human", "content": "解释个人电脑和服务器之间的区别。"}, {"from": "assistant", "content": "个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。"}]}
+```
+
+Command to convert jsonl dataset to arrow format is similar to the command in [3.1 Data for Pretraining](#31-data-for-pretraining). In `prepare_sft_dataset.py`, we don't concatenate different data samples.
+```
+python prepare_sft_dataset.py.py \
+ --data_input_dirs ",," \
+ --tokenizer_dir "" \
+ --data_cache_dir "jsonl_to_arrow_cache" \
+ --data_jsonl_output_dir "spliced_tokenized_output_jsonl" \
+ --data_arrow_output_dir "spliced_tokenized_output_arrow" \
+ --max_length 4096 \
+ --num_spliced_dataset_bins 10
+```
+
#### 4. Command Line Arguments for Training
+
+##### 4.1 Arguments for Pretraining
You can use `colossalai run` to launch multi-nodes training:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
@@ -288,7 +349,16 @@ Here is details about CLI arguments:
* Tensor parallelism size: `--tp`. TP size for 3d Parallelism. The default value is 1.
* Zero stage: `--zero`. Zero stage for 3d Parallelism. The default value is 1.
+##### 4.2 Arguments for Supervised Fine-tuning
+We add support for gradient accumulation and NEFTuning for supervised fine-tuning and thus there are two more arguments apart from the arguments listed in [4.1 Arguments for Pretraining](#41-arguments-for-pretraining).
+
+Here is details about CLI arguments:
+* Accumulation steps: `--accumulation_steps`. The default value is `8`.
+* NEFTuning: `--use_neft`. The default value is `False`. It can help improve the performance of chat models.
+
#### 5. Running Command
+
+##### 5.1 Command for Pretraining
An [example bash](train.example.sh) is also provided for the experiment. Here is the steps to run the experiment:
* Create your own hostfile: `cp hostfile.example hostfile`.
* Create your own bash: `cp train.example.sh train.sh`.
@@ -310,6 +380,10 @@ declare -a dataset=(
"/part-00000"
)
```
+
+##### 5.2 Command for Supervised Fine-tuning
+An [example bash](train_sft.example.sh) is provided. The only difference with the command for pretraining is the two arguments (`--accumulation_steps` and `--use_neft`) in the script. You can refer to [4.2 Arguments for Supervised Fine-tuning](#42-arguments-for-supervised-fine-tuning) for more details.
+
## Technical Insights
In order to enhance LLaMA-2's capabilities for understanding and generating Chinese content, The [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team proposes the continuation of pre-training the LLaMA-2 model using both Chinese and English corpora. The overall pipeline can be described as follows:
@@ -416,3 +490,11 @@ Applying the above process to perform knowledge transfer in any field allows for
year={2023}
}
```
+```bibtex
+@article{jain2023neftune,
+ title={NEFTune: Noisy Embeddings Improve Instruction Finetuning},
+ author={Jain, Neel and Chiang, Ping-yeh and Wen, Yuxin and Kirchenbauer, John and Chu, Hong-Min and Somepalli, Gowthami and Bartoldson, Brian R and Kailkhura, Bhavya and Schwarzschild, Avi and Saha, Aniruddha and others},
+ journal={arXiv preprint arXiv:2310.05914},
+ year={2023}
+}
+```
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/conversation.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/conversation.py
new file mode 100644
index 000000000..be27ff7bc
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/conversation.py
@@ -0,0 +1,96 @@
+# Copyright 2023 lm-sys@FastChat
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+from enum import Enum, auto
+from typing import List
+
+
+class SeparatorStyle(Enum):
+ ADD_BOS_EOS_TOKEN = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle
+ seps: List[str]
+
+ def clear(self):
+ self.messages = []
+
+ def get_prompt(self, length: int = None):
+ if length is None:
+ length = len(self.messages)
+
+ if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
+ ret = self.system
+ for role, message in self.messages[0:length]:
+ if message:
+ ret += role + ": " + self.seps[0] + message + self.seps[1]
+ else:
+ ret += role + ": " + self.seps[0]
+ return ret
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ def save_prompt(self):
+ if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
+ ret = self.system
+ for role, message in self.messages:
+ if message:
+ ret += role + ": " + self.seps[0] + message + self.seps[1] + "\n"
+ else:
+ ret += role + ": " + self.seps[0]
+ return ret
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ seps=self.seps,
+ )
+
+ def dict(self):
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "seps": self.seps,
+ }
+
+
+conv = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ roles=("Human", "Assistant"),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
+ seps=["", ""],
+)
+
+default_conversation = conv
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py
index 0c21f325a..8314941ba 100644
--- a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py
@@ -4,22 +4,29 @@
Splicing multiple pre-tokenized sequence data points
"""
+import bisect
import random
import warnings
from copy import deepcopy
-from datasets import dataset_dict
-from typing import Any, Callable, Dict, Iterable, List, Union, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
+from datasets import dataset_dict
from torch.utils.data import ConcatDataset, Dataset, IterableDataset
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
+from colossalai.logging import get_dist_logger
+
+from .conversation import Conversation, default_conversation
+
+logger = get_dist_logger()
+
IGNORE_INDEX = -100
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
-def supervised_tokenize(
+def supervised_tokenize_pretrain(
data_point: Dict[str, str], tokenizer: LlamaTokenizer, ignore_index: int = None, max_length: int = 4096
) -> Dict[str, Union[int, str, List[int]]]:
"""
@@ -62,6 +69,121 @@ def supervised_tokenize(
)
+def supervised_tokenize_sft(
+ data_point: Dict[str, str],
+ tokenizer: LlamaTokenizer,
+ conversation_template: Conversation = default_conversation,
+ ignore_index: int = None,
+ max_length: int = 4096,
+) -> Dict[str, Union[int, str, List[int]]]:
+ """
+ A tokenization function to tokenize an original supervised data point as following:
+ {"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
+ """
+ assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, (
+ "Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, "
+ "add and manually later"
+ )
+
+ assert (
+ tokenizer.bos_token == conversation_template.seps[0] and tokenizer.eos_token == conversation_template.seps[1]
+ ), "`bos_token` and `eos_token` should be the same with `conversation_template.seps`."
+
+ if ignore_index is None:
+ ignore_index = IGNORE_INDEX
+
+ messages = data_point["messages"]
+ template = deepcopy(conversation_template)
+ template.messages = []
+
+ for mess in messages:
+ from_str = mess["from"]
+ if from_str.lower() == "human":
+ from_str = template.roles[0]
+ elif from_str.lower() == "assistant":
+ from_str = template.roles[1]
+ else:
+ raise ValueError(f"Unsupported role {from_str.lower()}")
+
+ template.append_message(from_str, mess["content"])
+
+ if len(template.messages) % 2 != 0:
+ template.messages = template.messages[0:-1]
+
+ # `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
+ turns = [i for i in range(1, len(messages) // 2 + 1)]
+ target_turn_index = bisect.bisect_right(
+ turns,
+ max_length - 1,
+ key=lambda x: len(tokenizer([template.get_prompt(2 * x)], add_special_tokens=False)["input_ids"][0]),
+ )
+
+ # The tokenized length for first turn already exceeds `max_length - 1`.
+ if target_turn_index - 1 < 0:
+ return dict(
+ input_ids=None,
+ labels=None,
+ inputs_decode=None,
+ labels_decode=None,
+ seq_length=None,
+ seq_category=None,
+ )
+
+ target_turn = turns[target_turn_index - 1]
+ prompt = template.get_prompt(2 * target_turn)
+ tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
+
+ template.messages = template.messages[0 : 2 * target_turn]
+
+ starts = []
+ ends = []
+ gpt_bos = False if template.messages[0][0] == template.roles[0] else True
+ gpt_eos = False if template.messages[0][0] == template.roles[0] else True
+
+ for i, token_id in enumerate(tokenized):
+ if token_id == tokenizer.bos_token_id:
+ if gpt_bos:
+ starts.append(i)
+ gpt_bos = not gpt_bos
+ elif token_id == tokenizer.eos_token_id:
+ if gpt_eos:
+ ends.append(i)
+ gpt_eos = not gpt_eos
+
+ if len(starts) != target_turn or len(ends) != target_turn:
+ logger.info(
+ "Please check whether the tokenizer add additional `bos_token` and `eos_token`.\n\nOr the original message contains `bos_token` or `eos_token`."
+ )
+ return dict(
+ input_ids=None,
+ labels=None,
+ inputs_decode=None,
+ labels_decode=None,
+ seq_length=None,
+ seq_category=None,
+ )
+
+ tokenized = [tokenizer.bos_token_id] + tokenized
+ labels = [ignore_index] * len(tokenized)
+ for start, end in zip(starts, ends):
+ labels[start + 1 : end + 2] = tokenized[start + 1 : end + 2]
+
+ labels_decode = deepcopy(labels)
+ for i, z in enumerate(labels_decode):
+ if z == ignore_index:
+ labels_decode[i] = tokenizer.unk_token_id
+
+ # `inputs_decode` and `labels_decode` can be used to check whether the tokenization method is true.
+ return dict(
+ input_ids=tokenized,
+ labels=labels,
+ inputs_decode=tokenizer.decode(tokenized),
+ labels_decode=tokenizer.decode(labels_decode),
+ seq_length=len(tokenized),
+ seq_category=data_point["category"] if "category" in data_point else "None",
+ )
+
+
class ClosedToConstantLengthSplicedDataset(IterableDataset):
"""
Define an iterable dataset that returns a (close to) constant length data point spliced from multiple
@@ -169,12 +291,7 @@ class ClosedToConstantLengthSplicedDataset(IterableDataset):
spliced_labels.extend(seq_labels)
# For residual spliced data point at the end of the data set
if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0:
- examples.append(
- {
- self.input_ids_field: spliced_input_ids,
- self.labels_field: spliced_labels
- }
- )
+ examples.append({self.input_ids_field: spliced_input_ids, self.labels_field: spliced_labels})
if self.shuffle:
random.shuffle(examples)
for spliced_data_point in examples:
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py
new file mode 100644
index 000000000..079faaace
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py
@@ -0,0 +1,69 @@
+# Copyright 2023 The Hugging Face team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+def unwrap(model):
+ return model.unwrap().module
+
+
+def neftune_post_forward_hook(module, input, output):
+ """
+ Implements the NEFTune forward pass for the model using forward hooks. Note this works only for torch.nn.Embedding
+ layers. This method is slightly adapted from the original source code that can be found here:
+ https://github.com/neelsjain/NEFTune Simply add it to your model as follows:
+ ```python
+ model = ...
+ model.embed_tokens.neftune_noise_alpha = 0.1
+ model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
+ ```
+ Args:
+ module (`torch.nn.Module`):
+ The embedding module where the hook is attached. Note that you need to set `module.neftune_noise_alpha` to
+ the desired noise alpha value.
+ input (`torch.Tensor`):
+ The input tensor to the model.
+ output (`torch.Tensor`):
+ The output tensor of the model (i.e. the embeddings).
+ """
+ if module.training:
+ dims = torch.tensor(output.size(1) * output.size(2))
+ mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
+ output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
+ return output
+
+
+def activate_neftune(model, neftune_noise_alpha=0.1):
+ r"""
+ Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
+ https://arxiv.org/abs/2310.05914
+ """
+ embeddings = unwrap(model).get_input_embeddings()
+
+ embeddings.neftune_noise_alpha = neftune_noise_alpha
+ hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
+ neftune_hook_handle = hook_handle
+
+ return model, neftune_hook_handle
+
+
+def deactivate_neftune(model, neftune_hook_handle):
+ """
+ Deactivates the neftune method. Make sure to call `_activate_neftune` first.
+ """
+ embeddings = unwrap(model).get_input_embeddings()
+
+ neftune_hook_handle.remove()
+ del embeddings.neftune_noise_alpha
diff --git a/applications/Colossal-LLaMA-2/inference_example.py b/applications/Colossal-LLaMA-2/inference_example.py
new file mode 100644
index 000000000..7fe2d92ab
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/inference_example.py
@@ -0,0 +1,57 @@
+import argparse
+import os
+
+import torch
+from colossalai.logging import get_dist_logger
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+logger = get_dist_logger()
+
+
+def load_model(model_path, device="cuda", **kwargs):
+ logger.info(
+ "Please check whether the tokenizer and model weights are properly stored in the same folder."
+ )
+ model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
+ model.to(device)
+
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ except OSError:
+ raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.")
+
+ return model, tokenizer
+
+
+@torch.inference_mode()
+def generate(args):
+ model, tokenizer = load_model(model_path=args.model_path, device=args.device)
+
+ BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
+ input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}"
+
+ inputs = tokenizer(args.input_txt, return_tensors='pt').to(args.device)
+ output = model.generate(**inputs,
+ max_new_tokens=args.max_new_tokens,
+ do_sample=args.do_sample,
+ temperature=args.temperature,
+ top_k=args.top_k,
+ top_p=args.top_p,
+ num_return_sequences=1)
+ response = tokenizer.decode(output.cpu()[0], skip_special_tokens=True)[len(input_txt):]
+ logger.info(f"Question: {input_txt} \n\n Answer: \n{response}")
+ return response
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Colossal-LLaMA-2 inference Process.")
+ parser.add_argument('--model_path', type=str, default="hpcai-tech/Colossal-LLaMA-2-7b-base", help="HF repo name or local path of the model")
+ parser.add_argument('--device', type=str, default="cuda:0", help="Set the device")
+ parser.add_argument('--max_new_tokens', type=int, default=512, help=" Set maximum numbers of tokens to generate, ignoring the number of tokens in the prompt")
+ parser.add_argument('--do_sample', type=bool, default=True, help="Set whether or not to use sampling")
+ parser.add_argument('--temperature', type=float, default=0.3, help="Set temperature value")
+ parser.add_argument('--top_k', type=int, default=50, help="Set top_k value for top-k-filtering")
+ parser.add_argument('--top_p', type=int, default=0.95, help="Set top_p value for generation")
+ parser.add_argument('--input_txt', type=str, default="明月松间照,", help="The prompt input to the model")
+ args = parser.parse_args()
+ generate(args)
\ No newline at end of file
diff --git a/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py b/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py
index a519232f6..cb578b5f6 100644
--- a/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py
+++ b/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py
@@ -11,14 +11,14 @@ import os
import time
from multiprocessing import cpu_count
+from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
+ ClosedToConstantLengthSplicedDataset,
+ supervised_tokenize_pretrain,
+)
from datasets import dataset_dict, load_dataset
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from colossalai.logging import get_dist_logger
-from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
- supervised_tokenize,
- ClosedToConstantLengthSplicedDataset,
-)
logger = get_dist_logger()
@@ -104,7 +104,7 @@ def main():
assert isinstance(dataset, dataset_dict.Dataset)
logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.")
dataset = dataset.map(
- function=supervised_tokenize,
+ function=supervised_tokenize_pretrain,
fn_kwargs={"tokenizer": tokenizer, "max_length": args.max_length},
keep_in_memory=False,
num_proc=min(len(dataset), cpu_count()),
@@ -149,5 +149,5 @@ def main():
spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count()))
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/applications/Colossal-LLaMA-2/prepare_sft_dataset.py b/applications/Colossal-LLaMA-2/prepare_sft_dataset.py
new file mode 100644
index 000000000..6d19cbd72
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/prepare_sft_dataset.py
@@ -0,0 +1,147 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Prepare sft dataset for fine-tuning
+"""
+
+import argparse
+import json
+import math
+import os
+from multiprocessing import cpu_count
+
+from colossal_llama2.dataset.conversation import default_conversation
+from colossal_llama2.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft
+from datasets import dataset_dict, load_dataset
+from transformers.models.llama.tokenization_llama import LlamaTokenizer
+
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger()
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--data_input_dirs",
+ type=str,
+ required=True,
+ default=None,
+ help="Comma(i.e., ',') separated list of all data directories containing `.jsonl` data files.",
+ )
+ parser.add_argument(
+ "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer"
+ )
+ parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
+ parser.add_argument(
+ "--data_jsonl_output_dir",
+ type=str,
+ default="jsonl_output",
+ help="Output directory of spliced dataset with jsonl format",
+ )
+ parser.add_argument(
+ "--data_arrow_output_dir",
+ type=str,
+ default="arrow_output",
+ help="Output directory of spliced dataset with arrow format",
+ )
+ parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence")
+ parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins")
+ args = parser.parse_args()
+
+ if args.num_spliced_dataset_bins >= 100000:
+ raise ValueError("Too many spliced divisions, must be smaller than 100000")
+
+ assert not os.path.exists(args.data_cache_dir), f"Find existed data cache dir {args.data_cache_dir}"
+ assert not os.path.exists(
+ args.data_jsonl_output_dir
+ ), f"Find existed jsonl data output dir {args.data_jsonl_output_dir}"
+ assert not os.path.exists(
+ args.data_arrow_output_dir
+ ), f"Find existed arrow data output dir {args.data_arrow_output_dir}"
+ os.makedirs(args.data_jsonl_output_dir)
+ os.makedirs(args.data_arrow_output_dir)
+
+ # Prepare to all input datasets
+ input_data_paths = []
+ input_data_dirs = args.data_input_dirs.split(",")
+ for ds_dir in input_data_dirs:
+ ds_dir = os.path.abspath(ds_dir)
+ assert os.path.exists(ds_dir), f"Not find data dir {ds_dir}"
+ ds_files = [name for name in os.listdir(ds_dir) if name.endswith(".jsonl")]
+ ds_paths = [os.path.join(ds_dir, name) for name in ds_files]
+ input_data_paths.extend(ds_paths)
+
+ # Prepare to data splitting.
+ train_splits = []
+ split_interval = math.ceil(100 / args.num_spliced_dataset_bins)
+ for i in range(0, 100, split_interval):
+ start = i
+ end = i + split_interval
+ if end > 100:
+ end = 100
+ train_splits.append(f"train[{start}%:{end}%]")
+
+ # Prepare to the tokenizer.
+ tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir)
+ tokenizer.add_bos_token = False
+ tokenizer.add_eos_token = False
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.unk_token
+
+ list_dataset = load_dataset(
+ path="json",
+ data_files=input_data_paths,
+ cache_dir=os.path.join(args.data_cache_dir, "raw"),
+ keep_in_memory=False,
+ split=train_splits,
+ num_proc=cpu_count(),
+ )
+ for index, dataset in enumerate(list_dataset):
+ assert isinstance(dataset, dataset_dict.Dataset)
+ logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.")
+ dataset = dataset.map(
+ function=supervised_tokenize_sft,
+ fn_kwargs={
+ "tokenizer": tokenizer,
+ "conversation_template": default_conversation,
+ "max_length": args.max_length,
+ },
+ keep_in_memory=False,
+ num_proc=min(len(dataset), cpu_count()),
+ )
+
+ dataset = dataset.filter(lambda data: data["labels"] is not None)
+ dataset = dataset.sort(column_names=("seq_category", "seq_length"), reverse=False, keep_in_memory=False)
+
+ # We don't concatenate data samples here.
+ spliced_dataset = dataset
+ # Save each jsonl spliced dataset.
+ output_index = "0" * (5 - len(str(index))) + str(index)
+ output_name = f"part-{output_index}"
+ output_jsonl_path = os.path.join(args.data_jsonl_output_dir, output_name + ".jsonl")
+ # st = time.time()
+ with open(file=output_jsonl_path, mode="w", encoding="utf-8") as fp_writer:
+ spliced_count = 0
+ for spliced_data_point in spliced_dataset:
+ if spliced_count % 500 == 0:
+ logger.info(f"processing {spliced_count} spliced data points for {fp_writer.name}")
+ spliced_count += 1
+ fp_writer.write(json.dumps(spliced_data_point, ensure_ascii=False) + "\n")
+
+ # Save each arrow spliced dataset
+ output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name)
+ logger.info(f"Start to save {output_arrow_path}")
+ spliced_dataset = load_dataset(
+ path="json",
+ data_files=[output_jsonl_path],
+ cache_dir=os.path.join(args.data_cache_dir, "spliced_and_tokenized"),
+ keep_in_memory=False,
+ num_proc=cpu_count(),
+ split="train",
+ )
+ spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count()))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/applications/Colossal-LLaMA-2/train_sft.example.sh b/applications/Colossal-LLaMA-2/train_sft.example.sh
new file mode 100755
index 000000000..dcb11515d
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/train_sft.example.sh
@@ -0,0 +1,46 @@
+#!/bin/bash
+
+# NCCL IB environment variables
+export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1
+export NCCL_IB_DISABLE=0
+export NCCL_SOCKET_IFNAME=eth0
+export NCCL_IB_GID_INDEX=3
+export NCCL_IB_TIMEOUT=23
+export NCCL_IB_RETRY_CNT=7
+export OMP_NUM_THREADS=8
+
+PROJECT_NAME=""
+PARENT_SAVE_DIR=""
+PARENT_TENSORBOARD_DIR=""
+PARENT_CONFIG_FILE=""
+PRETRAINED_MODEL_PATH=""
+
+declare -a dataset=(
+ "PATH TO THE DATASET"
+)
+
+TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
+FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
+SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
+TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}"
+CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
+
+colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_sft.py \
+ --pretrained $PRETRAINED_MODEL_PATH \
+ --dataset ${dataset[@]} \
+ --plugin "zero2" \
+ --save_interval 400 \
+ --save_dir $SAVE_DIR \
+ --tensorboard_dir $TENSORBOARD_DIR \
+ --config_file $CONFIG_FILE \
+ --num_epochs 1 \
+ --accumulation_steps 8 \
+ --micro_batch_size 8 \
+ --lr 5e-5 \
+ --mixed_precision "bf16" \
+ --grad_clip 1.0 \
+ --weight_decay 0.01 \
+ --warmup_steps 100 \
+ --use_grad_checkpoint \
+ --use_flash_attn \
+ --use_neft \
diff --git a/applications/Colossal-LLaMA-2/train_sft.py b/applications/Colossal-LLaMA-2/train_sft.py
new file mode 100644
index 000000000..fd9e1cd3e
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/train_sft.py
@@ -0,0 +1,403 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Supervised fine-tuning of Colossal-LLaMA-2-base developed by Colossal-AI Team
+"""
+
+import argparse
+import json
+import os
+import resource
+from contextlib import nullcontext
+
+import torch
+import torch.distributed as dist
+from colossal_llama2.dataset.loader import (
+ DataCollatorForSupervisedDataset,
+ StatefulDistributedSampler,
+ load_tokenized_dataset,
+ setup_distributed_dataloader,
+)
+from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
+from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
+from colossal_llama2.utils.froze import freeze_non_embeds_parameters
+from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.utils import get_current_device
+
+
+def get_model_numel(model: torch.nn.Module) -> int:
+ return sum(p.numel() for p in model.parameters())
+
+
+def format_numel_str(numel: int) -> str:
+ B = 1024**3
+ M = 1024**2
+ K = 1024
+ if numel >= B:
+ return f"{numel / B:.2f} B"
+ elif numel >= M:
+ return f"{numel / M:.2f} M"
+ elif numel >= K:
+ return f"{numel / K:.2f} K"
+ else:
+ return f"{numel}"
+
+
+def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
+ dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
+ tensor.div_(dist.get_world_size())
+ return tensor
+
+
+def main() -> None:
+ # ==============================
+ # Parse Arguments
+ # ==============================
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--pretrained",
+ type=str,
+ default=None,
+ help="Address of the pre-trained modeling",
+ )
+ parser.add_argument("--dataset", nargs="+", default=[])
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default="gemini",
+ choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
+ help="Choose which plugin to use",
+ )
+ parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
+ parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
+ parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
+ parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
+ parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
+ parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
+ parser.add_argument("--accumulation_steps", type=int, default=8, help="Number of accumulation steps")
+ parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
+ parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
+ parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="fp16",
+ choices=["fp16", "bf16"],
+ help="Mixed precision",
+ )
+ parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
+ parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
+ parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
+ parser.add_argument(
+ "--use_grad_checkpoint",
+ action="store_true",
+ default=False,
+ help="Use gradient checkpointing",
+ )
+ parser.add_argument(
+ "--use_flash_attn",
+ action="store_true",
+ default=False,
+ help="Use flash-attention",
+ )
+ parser.add_argument(
+ "--use_neft",
+ action="store_true",
+ default=False,
+ help="Use NEFTune",
+ )
+ parser.add_argument(
+ "--freeze_non_embeds_params",
+ action="store_true",
+ default=False,
+ help="Freeze non embeddings parameters",
+ )
+ parser.add_argument("--tp", type=int, default=1)
+ parser.add_argument("--zero", type=int, default=1)
+ args = parser.parse_args()
+
+ with open(args.config_file, "w") as f:
+ json.dump(args.__dict__, f, indent=4)
+
+ # ==============================
+ # Initialize Distributed Training
+ # ==============================
+ colossalai.launch_from_torch({})
+ coordinator = DistCoordinator()
+
+ # ==============================
+ # Initialize Tensorboard
+ # ==============================
+ if coordinator.is_master():
+ os.makedirs(args.tensorboard_dir, exist_ok=True)
+ writer = SummaryWriter(args.tensorboard_dir)
+
+ # ==============================
+ # Initialize Booster
+ # ==============================
+ if args.plugin == "gemini":
+ plugin = GeminiPlugin(
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "gemini_auto":
+ plugin = GeminiPlugin(
+ precision=args.mixed_precision,
+ placement_policy="auto",
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "zero2":
+ plugin = LowLevelZeroPlugin(
+ stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "zero2_cpu":
+ plugin = LowLevelZeroPlugin(
+ stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ cpu_offload=True,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "3d":
+ plugin = HybridParallelPlugin(
+ tp_size=args.tp,
+ pp_size=1,
+ zero_stage=args.zero,
+ max_norm=args.grad_clip,
+ precision=args.mixed_precision,
+ )
+ else:
+ raise ValueError(f"Unknown plugin {args.plugin}")
+
+ booster = Booster(plugin=plugin)
+
+ # ======================================================
+ # Initialize Tokenizer, Dataset, Collator and Dataloader
+ # ======================================================
+ tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.add_bos_token = False
+ tokenizer.add_eos_token = False
+
+ coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
+ coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}")
+ coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}")
+
+ coordinator.print_on_master(f"Load dataset: {args.dataset}")
+
+ dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
+ dataloader = setup_distributed_dataloader(
+ dataset=dataset,
+ batch_size=args.micro_batch_size,
+ shuffle=True,
+ drop_last=True,
+ collate_fn=data_collator,
+ )
+ coordinator.print_on_master(
+ f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+
+ # ======================================================
+ # Initialize Model, Objective, Optimizer and LR Scheduler
+ # ======================================================
+ init_ctx = (
+ LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
+ )
+ with init_ctx:
+ model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
+ # Freeze part of parameters.
+ if args.freeze_non_embeds_params:
+ freeze_non_embeds_parameters(model=model)
+
+ if args.use_grad_checkpoint:
+ model.gradient_checkpointing_enable()
+ coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
+ if args.use_flash_attn:
+ replace_with_flash_attention(model=model)
+ coordinator.print_on_master(msg="Flash-attention enabled successfully")
+
+ model_numel = get_model_numel(model)
+ coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
+
+ optimizer = HybridAdam(
+ model_params=filter(lambda p: p.requires_grad, model.parameters())
+ if args.freeze_non_embeds_params
+ else model.parameters(),
+ lr=args.lr,
+ betas=(0.9, 0.95),
+ weight_decay=args.weight_decay,
+ adamw_mode=True,
+ )
+
+ if args.warmup_steps is None:
+ args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
+ coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
+
+ lr_scheduler = CosineAnnealingWarmupLR(
+ optimizer=optimizer,
+ total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
+ warmup_steps=args.warmup_steps,
+ eta_min=0.1 * args.lr,
+ )
+
+ # Flash attention will be disabled because it does NOT support fp32.
+ default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
+ torch.set_default_dtype(default_dtype)
+ model, optimizer, _, dataloader, lr_scheduler = booster.boost(
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ dataloader=dataloader,
+ )
+
+ torch.set_default_dtype(torch.float)
+
+ if args.load_checkpoint is None:
+ coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
+ booster.load_model(model, args.pretrained, strict=False)
+
+ coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+ coordinator.print_on_master(
+ f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+ )
+
+ start_epoch = 0
+ start_step = 0
+ sampler_start_idx = 0
+ if args.load_checkpoint is not None:
+ if "modeling" in args.load_checkpoint:
+ coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}")
+ booster.load_model(model, args.load_checkpoint)
+ else:
+ coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}")
+ start_epoch, start_step, sampler_start_idx = load_checkpoint(
+ load_dir=args.load_checkpoint,
+ booster=booster,
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ )
+ coordinator.print_on_master(
+ f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}"
+ )
+ coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
+
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+ )
+
+ if args.use_neft:
+ coordinator.print_on_master("Activate NEFTune.")
+ model, handle = activate_neftune(model)
+
+ num_steps_per_epoch = len(dataloader) // args.accumulation_steps
+ # If resume training, set the sampler start index to the correct value
+ assert isinstance(dataloader.sampler, StatefulDistributedSampler)
+ dataloader.sampler.set_start_index(start_index=sampler_start_idx)
+
+ for epoch in range(start_epoch, args.num_epochs):
+ dataloader.sampler.set_epoch(epoch=epoch)
+ pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch)
+ total_loss = torch.tensor(0.0).to(torch.cuda.current_device())
+ for step, batch in enumerate(dataloader):
+ batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
+
+ batch_output = model(**batch)
+
+ loss = batch_output.loss / args.accumulation_steps
+ total_loss += loss.item()
+
+ booster.backward(loss=loss, optimizer=optimizer)
+
+ if (step + 1) % args.accumulation_steps == 0:
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ all_reduce_mean(tensor=total_loss)
+ pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
+ if coordinator.is_master():
+ global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
+ writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
+ writer.add_scalar(
+ tag="Learning Rate",
+ scalar_value=lr_scheduler.get_last_lr()[0],
+ global_step=global_step,
+ )
+ total_loss.fill_(0.0)
+ pbar.update()
+ # Save modeling.
+
+ if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
+ step + 1
+ ) == len(dataloader):
+ coordinator.print_on_master("\nStart saving model checkpoint with running states")
+
+ if args.use_neft:
+ coordinator.print_on_master("Deactivate NEFTune before saving model.")
+ deactivate_neftune(model, handle)
+
+ save_checkpoint(
+ save_dir=args.save_dir,
+ booster=booster,
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ epoch=epoch,
+ step=step + 1,
+ batch_size=args.micro_batch_size,
+ coordinator=coordinator,
+ )
+ coordinator.print_on_master(
+ f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
+ )
+
+ if args.use_neft:
+ coordinator.print_on_master("Activate NEFTune.")
+ model, handle = activate_neftune(model)
+
+ # Delete CUDA cache.
+ # del batch, batch_labels, batch_output, loss
+ torch.cuda.empty_cache()
+
+ # the continue epochs are not resumed, so we need to reset the sampler start index and start step
+ dataloader.sampler.set_start_index(start_index=0)
+ start_step = 0
+
+ if args.use_neft:
+ coordinator.print_on_master("Deactivate NEFTune.")
+ deactivate_neftune(model, handle)
+
+ # Final save.
+ coordinator.print_on_master("Start saving final model checkpoint")
+ booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
+ coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
+
+ coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
+
+
+if __name__ == "__main__":
+ main()