InternLM/doc/usage.md

192 lines
11 KiB
Markdown
Raw Normal View History

2023-07-06 04:55:23 +00:00
## 基于InternLM的预训练与微调使用教程
启动一个 Demo 模型训练,需要进行三项准备,**安装****数据集准备**和**模型训练配置**。接下来,首先会介绍数据准备相关的操作,再简要描述模型训练配置相关的内容。
### 安装
请参考[安装文档](./install.md)进行安装。
### 数据准备 (预训练)
InternLM训练任务的数据集包括一系列的`bin`和`meta`文件。使用`tokenizer`从原始文本文件生成训练用数据集。通过在`tools/tokenizer.py`中指定模型参数路径的方式来导入tokenizer模型。目前提供`V7.model`来生成tokens。若想使用不同的模型可直接修改`tokernizer.py`中的模型参数路径。
可以运行以下命令生成原始数据对应的`bin`和`meta`文件,其中参数`raw_data_name`表示原始数据集的文件名称,`input_file_type`表示原始数据集的文件格式,目前支持`txt`、`json`和`jsonl`这三种格式,`bin`表示生成的`bin`文件的保存路径。
```bash
$ python tools/tokenizer.py --raw_data_name your_raw_data_file_name(without suffix) --input_file_type 'text' or 'json' or 'jsonl' --bin your_output_bin_path
```
下面是一个数据处理的例子(这里只给出了`txt`格式的数据处理例子,`json`和`jsonl`的数据处理流程和`txt`的完全一致):
给定一个包含原始数据集的文件`raw_data.txt`,原始数据集如下所示:
```bash
感恩生活中的每一个细节,才能真正体会到幸福的滋味。
梦想是人生的动力源泉,努力追逐,才能实现自己的目标。
学会宽容和理解,才能建立真正和谐的人际关系。
```
可以通过运行以下命令来生成`bin`和`meta`文件:
```bash
$ python tools/tokenizer.py --raw_data_name raw_data --input_file_type 'text' --bin cn/output.bin
```
需要注意的是,生成的`bin`文件需要保存在`cn`或者`en`或者`code`或者`ja`或者`ar`或者`kaoshi`这六个目录下,以区分数据集的类型。
其中,`cn`表示中文数据集;`en`表示英文数据集;`code`表示代码数据集;`ja`表示日语数据集;`ar`表示阿拉伯语数据集;`kaoshi`表示考试数据集。
生成的bin文件的格式如下
```python
{"tokens": [73075, 75302, 69522, 69022, 98899, 67713, 68015, 81269, 74637, 75445, 99157]}
{"tokens": [69469, 60355, 73026, 68524, 60846, 61844, 98899, 67775, 79241, 98899, 67713, 67800, 67453, 67838, 99157]}
{"tokens": [68057, 79017, 60378, 68014, 98899, 67713, 67990, 68015, 70381, 67428, 61003, 67622, 99157]}
```
`bin`文件中的每一行均对应原始数据集中的每一个句子,表示每个句子的`token`下文将用sequence指定
生成的`meta`文件的格式如下:
```bash
(0, 11), (90, 15), (208, 13)
```
在`meta`文件中,每个元组对应着`bin`文件中每一个`sequence`的元信息。其中,元组的第一个元素表示每个`sequence`在所有`sequence`中的`starting index`,第二个元素表示每个`sequence`中有多少个`tokens`。
例如,对于第一个`sequence``starting index`为 0有 11 个`tokens`;对于第二个`sequence`,由于第一个`sequence`转换为`string`后的长度为`89`,因此它的`starting index`为 90有 15 个`tokens`。
`json`和`jsonl`类型的文件的`bin`和`meta`文件格式和`txt`一致,此处不再赘叙。
### 数据准备 (微调)
微调任务的数据集格式与预训练任务保持一致,生成的数据格式为一系列的`bin`和`meta`文件。以下以 Alpaca 数据集为例,介绍微调的数据准备流程。
1. 下载 [Alpaca 数据集](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json)
2. 对 Alpaca 数据进行 tokenize使用以下命令
```shell
python tools/alpaca_tokenizer.py /path/to/alpaca_dataset /path/to/output_dataset /path/to/tokenizer --split_ratio 0.1
```
建议用户参考 alpaca_tokenizer.py 编写新的脚本对自己的数据集进行 tokenize
### 训练配置
以 7B Demo 的配置文件`configs/7B_sft.py`为例,介绍启动一个模型训练所需要进行的数据、模型和并行等相关的配置。
#### 数据配置
数据相关的关键参数配置及释义如下所示:
```python
TRAIN_FOLDER = "/path/to/dataset"
SEQ_LEN = 2048
data = dict(
seq_len=SEQ_LEN, # 数据样本长度,默认值为 2048
micro_num=1, # micro_num 是指在一次模型参数更新中会处理的 micro_batch 的数目,默认值为 1
micro_bsz=1, # packed_length = micro_bsz * SEQ_LEN为一次处理的 micro_batch 的数据大小,默认值为 1
total_steps=50000, # 总的所需执行的 step 的数目,默认值为 50000
min_length=50, # 若数据集文件中数据行数少于50将会被废弃
train_folder=TRAIN_FOLDER, # 数据集文件路径,默认值为 None若 train_folder 为空,则以自动生成的随机数据集进行训练测试
pack_sample_into_one=False, # 数据整理的逻辑,决定是按照 seq_len 维度或者是 sequence 的真实长度来进行attention计算
)
```
<div align="left">
<img src="./imgs/pack_into_one.png" width="550"/>
</div>
目前支持传入数据集文件路径`train_folder`,且要求文件格式如下:
```bash
- folder
- code
train_000.bin
train_000.bin.meta
```
数据集的详细内容可参考``数据准备``模块相关的介绍。
#### 模型配置
如果在启动训练时要加载模型 `checkpoint`,可进行如下相关配置:
```python
SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt"
MODEL_ONLY_FOLDER = "local:/path/to/load/init/model/ckpt"
LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt"
ckpt = dict(
save_ckpt_folder=SAVE_CKPT_FOLDER, # 存储模型和优化器 checkpoint 的路径
checkpoint_every=float("inf"), # 每多少个 step 存储一次 checkpoint默认值为 inf
load_model_only_folder=MODEL_ONLY_FOLDER, # 加载模型初始权重的路径,只加载模型权重,不加载优化器权重,训练将从第一个 step 开始
load_ckpt_folder=LOAD_CKPT_FOLDER, # 断点续训时,加载模型和优化器等权重的路径,将从指定的 step 恢复训练
load_optimizer=True, # 断点续训时,是否需要加载优化器权重,默认值为 True
)
```
注意:
- `load_model_only_folder`与`load_ckpt_folder`不能同时设置
- 路径若以 `local:` 为前缀,则存储在本地文件系统;若以 `boto3:` 为前缀,则存储在远程 oss 上
模型相关关键参数配置如下所示:
```python
model_type = "INTERNLM" # 模型类型,默认值为 "INTERNLM",对应模型结构初始化接口函数
NUM_ATTENTION_HEAD = 32
VOCAB_SIZE = 103168
HIDDEN_SIZE = 4096
NUM_LAYER = 32
MLP_RATIO = 8 / 3
model = dict(
checkpoint=False,
num_attention_heads=NUM_ATTENTION_HEAD,
embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
)
```
注意:用户可自定义模型类型名和模型结构,并配置相对应的模型参数。通过`utils/registry.py`下的`MODEL_INITIALIZER`对象进行模型初始化函数接口注册,在训练主函数`train.py`中初始化模型时,可通过`model_type`配置获取指定的模型初始化接口函数。
*如果基于 InternLM 7B继续训练可以参考 [ModelZoo](https://github.com/InternLM/InternLM/tree/main#model-zoo) 中 OpenXLab 链接下载权重*
#### 并行配置
训练并行配置样例如下:
```python
parallel = dict(
zero1=8,
pipeline=1,
tensor=1,
)
```
- zero1zero 并行策略,分如下三种情况,默认值为 -1
- 当`size <= 0`,则 zero1 进程组的大小等于数据并行进程组的大小,因此优化器状态参数将在数据并行范围内分配
- 当`size == 1`,则不使用 zero1 ,所有数据并行组保留完整的优化器状态参数
- 当`size > 1`且`size <= data_parallel_world_size`,则 zero1 进程组是数据并行进程组的子集
- pipeline流水线并行大小目前只支持 1默认值为 1
- tensor张量并行大小通常是每个节点的 GPU 数量,默认值为 1
注意:`数据并行大小 = 总的 GPU 数目 / 流水线并行大小 / 张量并行大小`
### 启动训练
完成了以上数据集准备和相关训练配置后,可启动 Demo 训练。接下来分别以 slurm 和 torch 环境为例,介绍训练启动方式。
若在 slurm 上启动分布式运行环境,多节点 16 卡的运行命令如下所示:
```bash
$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/7B_sft.py
```
若在 torch 上启动分布式运行环境,单节点 8 卡的运行命令如下所示:
```bash
$ torchrun --nnodes=1 --nproc-per-node=8 train.py --config ./configs/7B_sft.py
```
### 运行结果
以 slurm 上单机 8 卡的 Demo 训练配置为例,训练结果日志展示如下:
```bash
2023-07-04 21:40:14,148 INFO train.py:318 in record_current_batch_training_metrics -- step=17,loss=9.810295104980469,tgs (tokens per gpu per second)=4399.93,lr=3.8e-06,loss_scale=65536.0,grad_norm=4.177205427229359,micro_num=4,num_consumed_tokens=2359296,inf_nan_skip_batches=0,num_samples_in_batch=60,largest_length=1300,largest_batch=18,smallest_batch=13,adam_beta2=0.95,fwd_bwd_time=3.57
2023-07-04 21:40:17,825 INFO train.py:318 in record_current_batch_training_metrics -- step=18,loss=9.715232849121094,tgs (tokens per gpu per second)=4457.7,lr=4.000000000000001e-06,loss_scale=65536.0,grad_norm=5.018154183978863,micro_num=4,num_consumed_tokens=2490368,inf_nan_skip_batches=0,num_samples_in_batch=68,largest_length=1153,largest_batch=19,smallest_batch=16,adam_beta2=0.95,fwd_bwd_time=3.52
2023-07-04 21:40:21,526 INFO train.py:318 in record_current_batch_training_metrics -- step=19,loss=9.76744556427002,tgs (tokens per gpu per second)=4429.13,lr=4.2000000000000004e-06,loss_scale=65536.0,grad_norm=5.245329823265071,micro_num=4,num_consumed_tokens=2621440,inf_nan_skip_batches=0,num_samples_in_batch=70,largest_length=706,largest_batch=18,smallest_batch=17,adam_beta2=0.95,fwd_bwd_time=3.54
2023-07-04 21:40:25,227 INFO train.py:318 in record_current_batch_training_metrics -- step=20,loss=9.628969192504883,tgs (tokens per gpu per second)=4427.46,lr=4.4e-06,loss_scale=65536.0,grad_norm=5.503176552110271,micro_num=4,num_consumed_tokens=2752512,inf_nan_skip_batches=0,num_samples_in_batch=69,largest_length=915,largest_batch=20,smallest_batch=15,adam_beta2=0.95,fwd_bwd_time=3.55
2023-07-04 21:40:28,899 INFO train.py:318 in record_current_batch_training_metrics -- step=21,loss=9.690847396850586,tgs (tokens per gpu per second)=4464.18,lr=4.6e-06,loss_scale=65536.0,grad_norm=5.5336643273197526,micro_num=4,num_consumed_tokens=2883584,inf_nan_skip_batches=0,num_samples_in_batch=66,largest_length=870,largest_batch=17,smallest_batch=16,adam_beta2=0.95,fwd_bwd_time=3.52
2023-07-04 21:40:32,629 INFO train.py:318 in record_current_batch_training_metrics -- step=22,loss=9.61986255645752,tgs (tokens per gpu per second)=4393.28,lr=4.800000000000001e-06,loss_scale=65536.0,grad_norm=9.01168869536059,micro_num=4,num_consumed_tokens=3014656,inf_nan_skip_batches=0,num_samples_in_batch=65,largest_length=1151,largest_batch=20,smallest_batch=14,adam_beta2=0.95,fwd_bwd_time=3.57
```