[application] add lora sft example (#6192)

* [application] add lora sft example

* update requirements

* update readme

* update comment

* update ci
pull/6198/head
Hongxin Liu 2025-02-18 13:06:38 +08:00 committed by GitHub
parent d20c8ffd97
commit d54642a263
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 565 additions and 3 deletions

View File

@ -31,13 +31,12 @@ jobs:
- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install --no-cache-dir -v -e .
pip install --no-cache-dir -v -e .
- name: Install ChatGPT
run: |
cd applications/ColossalChat
pip install --no-cache-dir -v .
export BUILD_EXT=1
pip install --no-cache-dir -r examples/requirements.txt
- name: Install Transformers

View File

@ -29,6 +29,7 @@
- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
- [O1 Journey](#o1-journey)
- [Inference with Self-refined MCTS](#inference-with-self-refined-mcts)
- [SFT for DeepSeek V3/R1](#sft-for-deepseek-v3)
- [FAQ](#faq)
- [How to save/load checkpoint](#faq)
- [How to train with limited resources](#faq)
@ -389,6 +390,37 @@ You can find more examples in this [repo](https://github.com/XueFuzhao/Instructi
- Cannot abide by OpenAI's policy: When generating prompts from OpenAI API, it always abides by its policy. So no violation case is in the datasets.
</details>
## SFT for DeepSeek V3
We add a script to supervised-fintune the DeepSeek V3/R1 model with LoRA. The script is located in `examples/training_scripts/lora_fintune.py`. The script is similar to the SFT script for Coati7B, but with a few differences. This script is compatible with Peft.
### Dataset preparation
This script receives JSONL format file as input dataset. Each line of dataset should be a list of chat dialogues. E.g.
```json
[{"role": "user", "content": "Hello, how are you?"}, {"role": "assistant", "content": "I'm doing great. How can I help you today?"}]
```
```json
[{"role": "user", "content": "火烧赤壁 曹操为何不拨打119求救"}, {"role": "assistant", "content": "因为在三国时期还没有电话和现代的消防系统所以曹操无法拨打119求救。"}]
```
The dialogues can by multiple turns and it can contain system prompt. For more details, see the [chat_templating](https://huggingface.co/docs/transformers/main/chat_templating).
### Model weights preparation
We use bf16 weights for finetuning. If you downloaded fp8 DeepSeek V3/R1 weights, you can use the [script](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) to convert the weights to bf16 via GPU. For Ascend NPU, you can use this [script](https://gitee.com/ascend/ModelZoo-PyTorch/blob/master/MindIE/LLM/DeepSeek/DeepSeek-V2/NPU_inference/fp8_cast_bf16.py).
### Usage
After preparing the dataset and model weights, you can run the script with the following command:
```bash
colossalai run --hostfile path-to-host-file --nproc_per_node 8 lora_finetune.py --pretrained path-to-DeepSeek-R1-bf16 --dataset path-to-dataset.jsonl --plugin moe --lr 2e-5 --max_length 256 -g --ep 8 --pp 3 --batch_size 24 --lora_rank 8 --lora_alpha 16 --num_epochs 2 --warmup_steps 8 --tensorboard_dir logs --save_dir DeepSeek-R1-bf16-lora
```
For more details of each argument, you can run `python lora_finetune.py --help`.
The sample command does not use CPU offload to get better throughput. The minimum hardware requirement for sample command is 32 ascend 910B NPUs (with `ep=8,pp=4`) or 24 H100/H800 GPUs (with `ep=8,pp=3`). If you enable CPU offload by `--zero_cpu_offload`, the hardware requirement can be further reduced.
## FAQ
<details><summary><b>How to save/load checkpoint</b></summary>
@ -501,7 +533,7 @@ Thanks so much to all of our amazing contributors!
- Keep in a sufficiently high running speed
| Model Pair | Alpaca-7B ⚔ Coati-7B | Coati-7B ⚔ Alpaca-7B |
| :-----------: | :------------------: | :------------------: |
|:-------------:|:--------------------:|:--------------------:|
| Better Cases | 38 ⚔ **41** | **45** ⚔ 33 |
| Win Rate | 48% ⚔ **52%** | **58%** ⚔ 42% |
| Average Score | 7.06 ⚔ **7.13** | **7.31** ⚔ 6.82 |

View File

@ -8,6 +8,7 @@ import os
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Sequence, Union
import jsonlines
import torch
import torch.nn.functional as F
from coati.dataset.utils import chuncate_sequence, pad_to_max_len
@ -345,3 +346,77 @@ class StatefulDistributedSampler(DistributedSampler):
def set_start_index(self, start_index: int) -> None:
self.start_index = start_index
def apply_chat_template_and_mask(
tokenizer: PreTrainedTokenizer,
chat: List[Dict[str, str]],
max_length: Optional[int] = None,
padding: bool = True,
truncation: bool = True,
ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
tokens = []
assistant_mask = []
for i, msg in enumerate(chat):
msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True)
# remove unexpected bos token
if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
msg_tokens = msg_tokens[1:]
tokens.extend(msg_tokens)
if msg["role"] == "assistant":
assistant_mask.extend([True] * len(msg_tokens))
else:
assistant_mask.extend([False] * len(msg_tokens))
attention_mask = [1] * len(tokens)
if max_length is not None:
if padding and len(tokens) < max_length:
to_pad = max_length - len(tokens)
if tokenizer.padding_side == "right":
tokens.extend([tokenizer.pad_token_id] * to_pad)
assistant_mask.extend([False] * to_pad)
attention_mask.extend([0] * to_pad)
else:
tokens = [tokenizer.pad_token_id] * to_pad + tokens
assistant_mask = [False] * to_pad + assistant_mask
attention_mask = [0] * to_pad + attention_mask
if truncation and len(tokens) > max_length:
tokens = tokens[:max_length]
assistant_mask = assistant_mask[:max_length]
attention_mask = attention_mask[:max_length]
input_ids = torch.tensor(tokens, dtype=torch.long)
attention_mask = torch.tensor(attention_mask, dtype=torch.long)
labels = input_ids.clone()
labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
class RawConversationDataset(Dataset):
"""
Raw conversation dataset.
Each instance is a dictionary with fields `system`, `roles`, `messages`, `offset`, `sep_style`, `seps`.
"""
def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int) -> None:
self.tokenizer = tokenizer
self.raw_texts = []
with jsonlines.open(input_file) as f:
for line in f:
self.raw_texts.append(line)
self.tokenized_texts = [None] * len(self.raw_texts)
self.max_length = max_length
def __len__(self) -> int:
return len(self.raw_texts)
def __getitem__(self, index: int):
if self.tokenized_texts[index] is None:
message = self.raw_texts[index]
tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length)
self.tokenized_texts[index] = dict(tokens)
return self.tokenized_texts[index]

View File

@ -0,0 +1,455 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Supervised fine-tuning of MoE models like Deepseek V3/R1 on a downstream task.
"""
import argparse
import json
import os
import resource
from contextlib import nullcontext
from types import MethodType
import torch
import torch.distributed as dist
from coati.dataset.loader import RawConversationDataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import (
GeminiPlugin,
HybridParallelPlugin,
LowLevelZeroPlugin,
MoeHybridParallelPlugin,
Plugin,
TorchDDPPlugin,
)
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
def all_reduce_mean(loss: torch.Tensor, plugin: Plugin) -> torch.Tensor:
loss = loss.data
group = getattr(plugin, "dp_group", None)
dist.all_reduce(loss, group=group)
return loss / dist.get_world_size(group)
def train(args) -> None:
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch()
accelerator = get_accelerator()
coordinator = DistCoordinator()
# ==============================
# Initialize Booster
# ==============================
if args.plugin == "ddp":
plugin = TorchDDPPlugin(find_unused_parameters=True if args.use_grad_checkpoint is False else False)
elif args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2_cpu":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip,
)
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_fused_normalization=get_accelerator().is_available(),
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
microbatch_size=args.microbatch_size,
)
elif args.plugin == "moe":
plugin = MoeHybridParallelPlugin(
ep_size=args.ep,
tp_size=args.tp,
pp_size=args.pp,
zero_stage=args.zero_stage,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
enable_sequence_parallelism=args.sp > 1,
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
max_norm=args.grad_clip,
precision=args.mixed_precision,
microbatch_size=args.microbatch_size,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
booster = Booster(plugin=plugin)
def is_master():
if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1:
return coordinator.rank == coordinator.world_size - 1
return coordinator.is_master()
# ==============================
# Initialize Tensorboard and Save Config
# ==============================
if is_master():
if args.tensorboard_dir is not None:
from torch.utils.tensorboard import SummaryWriter
os.makedirs(args.tensorboard_dir, exist_ok=True)
writer = SummaryWriter(args.tensorboard_dir)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
# ======================================================
# Initialize Tokenizer, Dataset, Collator and Dataloader
# ======================================================
tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
coordinator.print_on_master(
f"Training Info:\nConfig file: {args.config_file} \nTensorboard logs: {args.tensorboard_dir} \nModel checkpoint: {args.save_dir}"
)
coordinator.print_on_master(f"Load dataset: {args.dataset}")
dataset = RawConversationDataset(
tokenizer,
args.dataset,
args.max_length,
)
dataloader = plugin.prepare_dataloader(
dataset=dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
)
coordinator.print_on_master(
f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
# When training the ChatGLM model, LoRA and gradient checkpointing are incompatible.
init_ctx = (
LazyInitContext(default_device=get_current_device())
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
else nullcontext()
)
attn_impl = "eager" if get_accelerator().name == "npu" else "flash_attention_2"
config = AutoConfig.from_pretrained(args.pretrained, trust_remote_code=True)
with init_ctx:
# from_pretrained is not compatible with LoRA, we load pretrained weights later.
# model = AutoModelForCausalLM.from_pretrained(
# args.pretrained,
# torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
# trust_remote_code=True,
# attn_implementation=attn_impl,
# )
model = AutoModelForCausalLM.from_config(
config,
trust_remote_code=True,
attn_implementation=attn_impl,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
)
if args.lora_rank > 0:
if model.__class__.__name__.startswith("DeepseekV3"):
lora_config = LoraConfig(
task_type="CAUSAL_LM",
r=args.lora_rank,
lora_alpha=args.lora_alpha,
target_modules=["gate_proj", "up_proj", "down_proj"],
)
else:
lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=args.lora_alpha)
model = booster.enable_lora(model, lora_config=lora_config)
# this is essential, otherwise the grad checkpoint will not work.
model.train()
if args.use_grad_checkpoint:
model.gradient_checkpointing_enable()
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
if model.config.__class__.__name__.startswith("DeepseekV3"):
model.config.use_cache = False
model.eval()
# enable grad for moe layers
for m in model.modules():
if m.__class__.__name__ == "DeepseekV3MoE":
m.moe_infer = MethodType(m.moe_infer.__wrapped__, m)
model_numel = sum(p.numel() for p in model.parameters())
coordinator.print_on_master(f"Model params: {model_numel / 1e9:.2f} B")
optimizer = HybridAdam(
model_params=model.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
if args.warmup_steps is None:
args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
lr_scheduler = CosineAnnealingWarmupLR(
optimizer=optimizer,
total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
warmup_steps=args.warmup_steps,
eta_min=0.1 * args.lr,
)
# Flash attention will be disabled because it does NOT support fp32.
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
dataloader=dataloader,
)
torch.set_default_dtype(torch.float)
booster.load_model(model, args.pretrained)
coordinator.print_on_master(
f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
start_epoch = 0
start_step = 0
num_steps_per_epoch = len(dataloader) // args.accumulation_steps
for epoch in range(start_epoch, args.num_epochs):
dataloader.sampler.set_epoch(epoch=epoch)
if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1:
data_iter = iter(dataloader)
step_bar = tqdm(
range(len(dataloader)),
desc="Step",
disable=not is_master(),
)
for step in step_bar:
outputs = booster.execute_pipeline(
data_iter,
model,
criterion=lambda outputs, inputs: outputs[0],
optimizer=optimizer,
return_loss=True,
)
loss = outputs["loss"]
if booster.plugin.stage_manager.is_last_stage():
global_loss = all_reduce_mean(loss, plugin)
optimizer.step()
if booster.plugin.stage_manager.is_last_stage():
grad_norm = optimizer.get_grad_norm()
step_bar.set_postfix({"loss": global_loss.item(), "grad_norm": grad_norm})
if args.tensorboard_dir is not None and is_master():
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
writer.add_scalar(tag="Loss", scalar_value=global_loss.item(), global_step=global_step)
writer.add_scalar(
tag="Learning Rate",
scalar_value=lr_scheduler.get_last_lr()[0],
global_step=global_step,
)
writer.add_scalar(tag="Grad Norm", scalar_value=grad_norm, global_step=global_step)
lr_scheduler.step()
optimizer.zero_grad()
else:
pbar = tqdm(
dataloader,
desc=f"Epoch {epoch}",
disable=not is_master(),
initial=start_step // args.accumulation_steps,
)
total_loss = torch.tensor(0.0, device=get_current_device())
for step, batch in enumerate(pbar, start=start_step // args.accumulation_steps):
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
batch_output = model(**batch)
loss = batch_output.loss / args.accumulation_steps
total_loss.add_(loss.data)
booster.backward(loss=loss, optimizer=optimizer)
if (step + 1) % args.accumulation_steps == 0:
all_reduce_mean(total_loss, plugin)
optimizer.step()
grad_norm = optimizer.get_grad_norm()
pbar.set_postfix({"loss": total_loss.item(), "grad_norm": grad_norm})
if args.tensorboard_dir is not None and is_master():
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
writer.add_scalar(
tag="Learning Rate",
scalar_value=lr_scheduler.get_last_lr()[0],
global_step=global_step,
)
writer.add_scalar(tag="Grad Norm", scalar_value=grad_norm, global_step=global_step)
lr_scheduler.step()
optimizer.zero_grad()
total_loss.fill_(0.0)
# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator.empty_cache()
# Final save.
coordinator.print_on_master("Start saving final model checkpoint")
if args.lora_rank > 0:
booster.save_lora_as_pretrained(model, os.path.join(args.save_dir, "lora"))
else:
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Basic training information.
parser.add_argument(
"-m",
"--pretrained",
type=str,
required=True,
help="Address of the pre-trained model",
)
parser.add_argument("-d", "--dataset", type=str, required=True, help="Raw Jonl dataset for training.")
parser.add_argument(
"-p",
"--plugin",
type=str,
default="zero2",
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp", "moe"],
help="Choose which plugin to use",
)
parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
parser.add_argument("--tensorboard_dir", type=str, default=None, help="Tensorboard directory")
parser.add_argument("--config_file", type=str, default="training_config.json", help="Config file")
# Training parameters
parser.add_argument("-n", "--num_epochs", type=int, default=1, help="Number of training epochs")
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
parser.add_argument("--batch_size", type=int, default=2, help="Global Batch size of each process")
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
parser.add_argument("--max_length", type=int, default=8192, help="Model max length")
parser.add_argument(
"--mixed_precision",
type=str,
default="bf16",
choices=["fp16", "bf16"],
help="Mixed precision",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument(
"-g",
"--use_grad_checkpoint",
action="store_true",
default=False,
help="Use gradient checkpointing",
)
parser.add_argument(
"-f",
"--use_flash_attn",
action="store_true",
default=False,
help="Use flash-attention",
)
# Additional arguments for 3d plugin.
parser.add_argument("--tp", type=int, default=1, help="TP size, used for 3d plugin.")
parser.add_argument("--pp", type=int, default=1, help="PP size, used for 3d plugin.")
parser.add_argument("--sp", type=int, default=1, help="SP size, used for 3d plugin.")
parser.add_argument("--ep", type=int, default=1, help="EP size, used for moe plugin.")
parser.add_argument("--zero_stage", type=int, default=1, help="Zero stage, used for 3d plugin.", choices=[0, 1, 2])
parser.add_argument(
"--sp_mode",
type=str,
default="split_gather",
choices=["split_gather", "ring", "all_to_all"],
help="SP mode, used for 3d plugin.",
)
parser.add_argument(
"--enable_sequence_parallelism",
default=False,
action="store_true",
help="Whether to enable SP, used for 3d plugin.",
)
parser.add_argument(
"--zero_cpu_offload", default=False, action="store_true", help="Whether to use offloading, used for 3d plugin."
)
parser.add_argument(
"--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
)
parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.")
parser.add_argument("--lora_alpha", type=int, default=8, help="lora alpha when using lora to train.")
args = parser.parse_args()
if args.plugin in ["3d", "moe"] and args.pp > 1 and args.accumulation_steps > 1:
raise ValueError("Accumulation steps should be 1 when using PP. Please adjust batch size directly.")
train(args)

View File

@ -21,3 +21,4 @@ ninja==1.11.1
sentencepiece==0.1.99
flash-attn
tiktoken
jsonlines