Upgrade the PPO, DPO, and RM training scripts

pull/5759/head
YeAnbang 2024-05-28 03:04:39 +00:00
parent 7a7e86987d
commit 929e1e3da4
15 changed files with 169 additions and 139 deletions

View File

@@ -247,7 +247,7 @@ def apply_rlhf_data_format(
target_turn = int(len(template.messages) / 2)
prompt = template.get_prompt(target_turn * 2)
chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt,
tempalte.end_of_assistant)
template.end_of_assistant)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
loss_mask = [0] * len(tokenized)
mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id
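
The loss mask built here starts as all zeros; the lines that follow this hunk (not shown) presumably flip it to 1 only over the assistant spans returned by tokenize_and_concatenate, so loss is computed on responses but not on prompts. A minimal, self-contained sketch of that span-masking pattern — the helper name and span format are made up for illustration and are not ColossalChat's actual API:

    def build_loss_mask(num_tokens: int, spans: list[tuple[int, int]]) -> list[int]:
        """Set the mask to 1 over assistant token spans, 0 everywhere else (sketch)."""
        mask = [0] * num_tokens              # default: no loss on prompt/user tokens
        for start, end in spans:
            for i in range(start, min(end, num_tokens)):
                mask[i] = 1                  # compute loss only on assistant tokens
        return mask

    # Example: 10 tokens, assistant replies occupy tokens 3-5 and 8-9.
    print(build_loss_mask(10, [(3, 6), (8, 10)]))  # [0, 0, 0, 1, 1, 1, 0, 0, 1, 1]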

View File

@@ -32,3 +32,9 @@ class Critic(BaseModel):
)
values = self.value_head(sequence_hidden_states).squeeze(-1) # ensure shape is (B, sequence length)
return values
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def get_output_embeddings(self):
return self.model.get_output_embeddings()
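
These new accessors on Critic (and the identical pair added to RewardModel in the next file) simply delegate to the wrapped transformer, so utilities that expect the Hugging Face accessor interface — shardformer policies, tokenizer-resize helpers and the like — still work when they are handed the wrapper instead of the underlying model. A toy sketch of the delegation pattern, with made-up class names rather than the real Critic/RewardModel:

    import torch.nn as nn

    class ToyBackbone(nn.Module):
        """Stand-in for a Hugging Face causal LM (hypothetical)."""
        def __init__(self, vocab: int = 16, hidden: int = 8):
            super().__init__()
            self.embed = nn.Embedding(vocab, hidden)
            self.lm_head = nn.Linear(hidden, vocab, bias=False)

        def get_input_embeddings(self):
            return self.embed

        def get_output_embeddings(self):
            return self.lm_head

    class Wrapper(nn.Module):
        """Critic/RewardModel-style wrapper keeping the backbone under self.model."""
        def __init__(self, backbone: nn.Module):
            super().__init__()
            self.model = backbone
            self.value_head = nn.Linear(8, 1)

        # Forward the accessors so callers can still reach the embeddings
        # through the wrapper.
        def get_input_embeddings(self):
            return self.model.get_input_embeddings()

        def get_output_embeddings(self):
            return self.model.get_output_embeddings()

    w = Wrapper(ToyBackbone())
    assert w.get_input_embeddings() is w.model.embed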

View File

@@ -36,3 +36,9 @@ class RewardModel(BaseModel):
)
values = self.value_head(sequence_hidden_states).squeeze(-1) # Ensure shape is (B,)
return values
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def get_output_embeddings(self):
return self.model.get_output_embeddings()

View File

@@ -1,8 +0,0 @@
{
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\\n\\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don\\'t know the answer to a question, please don\\'t share false information.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
2
],
"end_of_assistant": "<|endoftext|>"
}

View File

@@ -0,0 +1,9 @@
{
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
151645,
151643
],
"end_of_assistant": "<|im_end|>"
}
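
The new config uses the ChatML-style <|im_start|>/<|im_end|> markers. A quick way to sanity-check a chat_template string like this one is to render it directly with Jinja2 (a sketch assuming the jinja2 package is available; transformers' tokenizer.apply_chat_template renders the same template with a few extra conveniences):

    from jinja2 import Template  # assumption: jinja2 is installed

    # The chat_template string from the config above.
    chat_template = (
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
        "{% endfor %}"
        "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
    )

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
    print(Template(chat_template).render(messages=messages, add_generation_prompt=True))
    # <|im_start|>system
    # You are a helpful assistant.<|im_end|>
    # <|im_start|>user
    # Hello!<|im_end|>
    # <|im_start|>assistant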

View File

@@ -1,4 +1,4 @@
SAVE_DIR="/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf"
SAVE_DIR="/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B"
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
@@ -14,9 +14,9 @@ rm -rf $SAVE_DIR/arrow
python prepare_dataset.py --type sft \
--data_input_dirs /mnt/jfs-hdd/home/yeanbang/data/dataset/sft_data/alpaca/data_preprocessed/train \
--conversation_template_config /home/yeanbang/data/ColossalAI/applications/ColossalChat/config/conversation_template/llama2.json \
--tokenizer_dir "/mnt/jfs-hdd/share/models/Llama-2-7b-chat-hf" \
--data_input_dirs /home/yeanbang/data/experiment/dataset/sft_data/test/sft-data \
--conversation_template_config /home/yeanbang/data/ColossalAI/applications/ColossalChat/config/conversation_template/01-ai_Yi-1.5-9B-Chat.json \
--tokenizer_dir "/mnt/jfs-hdd/share/models/Yi-1.5-6B" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \

View File

@@ -56,6 +56,7 @@ def train(args):
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=True,
enable_flash_attention=args.use_flash_attn
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@@ -63,6 +64,7 @@ def train(args):
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
@@ -82,9 +84,15 @@ def train(args):
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=1,
zero_stage=0,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=True if args.sp > 1 else False,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
)
else:
@@ -172,7 +180,7 @@ def train(args):
shuffle=True,
drop_last=True,
collate_fn=data_collator,
use_tp=args.tp > 1,
tp_size=args.tp,
)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
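
The dataloader helper here (and the matching call sites in the PPO and reward-model scripts below) now receives the tensor-parallel degree itself rather than a use_tp boolean, presumably so the sampler can derive the data-parallel layout for any tp degree instead of only distinguishing "on" from "off". A sketch of the usual rank arithmetic, as an illustration only and not ColossalChat's actual implementation:

    def dp_layout(world_size: int, rank: int, tp_size: int) -> tuple[int, int]:
        """Return (dp_world_size, dp_rank) when ranks are grouped by tensor parallelism.

        Ranks in the same tensor-parallel group must see the same data shard, so the
        sampler should treat each tp group as a single data-parallel worker.
        """
        assert world_size % tp_size == 0, "tp_size must divide the world size"
        return world_size // tp_size, rank // tp_size

    # 8 GPUs with tp_size=2 -> 4 data-parallel workers; ranks 0 and 1 share shard 0.
    print(dp_layout(world_size=8, rank=1, tp_size=2))  # (4, 0)
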
@@ -290,6 +298,11 @@ if __name__ == "__main__":
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--model_type", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)

View File

@ -18,6 +18,7 @@ from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dr
from coati.trainer import PPOTrainer
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.shardformer.policies.auto_policy import get_autopolicy
import colossalai
from colossalai.booster import Booster
@@ -86,32 +87,6 @@ def train(args):
disable_dropout(actor)
disable_dropout(critic)
if args.tp > 1:
if reward_model.model.config.architectures[0] != critic.model.config.architectures[0]:
raise ValueError("Reward model and critic model must have the same architecture")
if reward_model.model.config.architectures[0] == "BloomForCausalLM":
from colossalai.shardformer.policies.bloom import BloomPolicy
booster_policy = BloomPolicy()
elif reward_model.model.config.architectures[0] == "LlamaForCausalLM":
from colossalai.shardformer.policies.llama import LlamaPolicy
booster_policy = LlamaPolicy()
elif reward_model.model.config.architectures[0] == "GPT2LMHeadModel":
from colossalai.shardformer.policies.gpt2 import GPT2Policy
booster_policy = GPT2Policy()
elif reward_model.model.config.architectures[0] == "ChatGLMModel":
from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
booster_policy = ChatGLMPolicy()
elif reward_model.model.config.architectures[0] == "OPTForCausalLM":
from colossalai.shardformer.policies.opt import OPTPolicy
booster_policy = OPTPolicy()
else:
raise ValueError("Unknown model architecture for policy")
if args.lora_rank > 0:
actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias)
critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias)
@@ -186,7 +161,7 @@ def train(args):
shuffle=True,
drop_last=True,
collate_fn=data_collator,
use_tp=args.tp > 1,
tp_size=args.tp,
)
if len(args.ptx_dataset) > 0:
@@ -198,7 +173,7 @@ def train(args):
shuffle=True,
drop_last=True,
collate_fn=data_collator,
use_tp=args.tp > 1,
tp_size=args.tp,
)
else:
train_pretrain_dataloader = None
@@ -237,6 +212,7 @@ def train(args):
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=True,
enable_flash_attention=args.use_flash_attn
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@@ -244,6 +220,7 @@ def train(args):
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
@@ -270,11 +247,17 @@ def train(args):
)
custom_plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=1,
zero_stage=0,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=True if args.sp > 1 else False,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
custom_policy=booster_policy,
custom_policy=get_autopolicy(reward_model.model),
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
@@ -474,6 +457,11 @@ if __name__ == "__main__":
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--rm_pretrain", type=str, default=None)
parser.add_argument("--checkpoint_path", type=str, default=None)

View File

@@ -15,7 +15,8 @@ from coati.dataset import (
from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module
from coati.trainer import RewardModelTrainer
from coati.utils import load_checkpoint
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoConfig
from colossalai.shardformer.policies.auto_policy import get_autopolicy
import colossalai
from colossalai.booster import Booster
@@ -56,31 +57,10 @@ def train(args):
)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
model = RewardModel(args.pretrain)
if args.tp > 1:
if model.model.config.architectures[0] == "BloomForCausalLM":
from colossalai.shardformer.policies.bloom import BloomPolicy
booster_policy = BloomPolicy()
elif model.model.config.architectures[0] == "LlamaForCausalLM":
from colossalai.shardformer.policies.llama import LlamaPolicy
booster_policy = LlamaPolicy()
elif model.model.config.architectures[0] == "GPT2LMHeadModel":
from colossalai.shardformer.policies.gpt2 import GPT2Policy
booster_policy = GPT2Policy()
elif model.model.config.architectures[0] == "ChatGLMModel":
from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
booster_policy = ChatGLMPolicy()
elif model.model.config.architectures[0] == "OPTForCausalLM":
from colossalai.shardformer.policies.opt import OPTPolicy
booster_policy = OPTPolicy()
else:
raise ValueError("Unknown model architecture for policy")
model_config = AutoConfig.from_pretrained(args.pretrain)
model = RewardModel(
args.pretrain,
)
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
@@ -100,6 +80,7 @@ def train(args):
placement_policy="static",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn,
enable_gradient_accumulation=True,
)
elif args.plugin == "gemini_auto":
@@ -107,6 +88,7 @@ def train(args):
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
enable_flash_attention=args.use_flash_attn,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2":
@@ -127,11 +109,17 @@ def train(args):
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=1,
zero_stage=0,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=True if args.sp > 1 else False,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
custom_policy=booster_policy,
custom_policy=get_autopolicy(model.model)
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
@@ -189,7 +177,7 @@ def train(args):
shuffle=True,
drop_last=True,
collate_fn=data_collator,
use_tp=args.tp > 1,
tp_size=args.tp,
)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
@@ -307,6 +295,11 @@ if __name__ == "__main__":
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])

View File

@@ -48,29 +48,29 @@ def train(args):
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True)
# check if the hybrid parallel plugin is compatible with the model
# try:
# from colossalai.shardformer.policies.auto_policy import get_autopolicy
# policy = get_autopolicy(model)
# if policy is not None:
# if args.plugin in ['zero2', 'zero2_cpu']:
# # if compatible, set the plugin to hybrid, which use colo-attention
# args.plugin = 'hybrid'
# args.zero_stage = 2
# if args.plugin == 'zero2_cpu':
# args.zero_cpu_offload = True
# else:
# args.zero_cpu_offload = False
# logger.info(f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}")
# except NotImplementedError:
# logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
# if args.use_flash_attn:
# del model
# model = AutoModelForCausalLM.from_pretrained(
# args.pretrain,
# torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
# attn_implementation="flash_attention_2",
# trust_remote_code=True
# )
try:
from colossalai.shardformer.policies.auto_policy import get_autopolicy
policy = get_autopolicy(model)
if policy is not None:
if args.plugin in ['zero2', 'zero2_cpu']:
# if compatible, set the plugin to 3d (hybrid parallel), which uses colo-attention
args.plugin = '3d'
args.zero_stage = 2
if args.plugin == 'zero2_cpu':
args.zero_cpu_offload = True
else:
args.zero_cpu_offload = False
logger.info(f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}")
except NotImplementedError:
logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
if args.use_flash_attn:
del model
model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
attn_implementation="flash_attention_2",
trust_remote_code=True
)
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
@@ -112,7 +112,7 @@ def train(args):
cpu_offload=True,
max_norm=args.grad_clip,
)
elif args.plugin == "hybrid":
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
@@ -224,7 +224,6 @@ def train(args):
lr_scheduler=lr_scheduler,
dataloader=train_dataloader,
)
# model = model.to(get_current_device())
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
@@ -309,7 +308,7 @@ if __name__ == "__main__":
"--plugin",
type=str,
default="gemini",
choices=["gemini", "gemini_auto", "hybrid", "ddp", "zero2_cpu", "zero2"],
choices=["gemini", "gemini_auto", "3d", "ddp", "zero2_cpu", "zero2"],
help="Choose which plugin to use",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")

View File

@@ -15,24 +15,24 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
# export CUDA_VISIBLE_DEVICES=4,5,6
set_n_least_used_CUDA_VISIBLE_DEVICES 4
set_n_least_used_CUDA_VISIBLE_DEVICES 2
PROJECT_NAME="sft"
PARENT_SAVE_DIR="/home/yeanbang/data/experiment/output/model" # Path to a folder to save checkpoints
PARENT_TENSORBOARD_DIR="/home/yeanbang/data/experiment/logs/tensorboard" # Path to a folder to save logs
PARENT_CONFIG_FILE="/home/yeanbang/data/experiment/logs/config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="/mnt/jfs-hdd/share/models/Llama-2-7b-chat-hf" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="/mnt/jfs-hdd/share/models/Llama-2-7b-chat-hf" # huggingface or local tokenizer path
PRETRAINED_MODEL_PATH="/mnt/jfs-hdd/share/models/Yi-1.5-6B" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="/mnt/jfs-hdd/share/models/Yi-1.5-6B" # huggingface or local tokenizer path
declare -a dataset=(
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00000
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00001
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00002
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00003
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00004
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00005
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00006
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00007
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00008
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00009
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00000
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00001
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00002
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00003
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00004
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00005
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00006
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00007
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00008
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00009
)
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
@@ -43,7 +43,7 @@ CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
echo $(which colossalai)
echo $(which python)
# the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size
colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile train_sft.py \
colossalai run --nproc_per_node 2 --master_port 31312 --hostfile ./hostfile train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--save_interval 4000 \
@@ -51,13 +51,13 @@ colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile trai
--save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
--lora_rank 0 \
--plugin zero2 \
--tp 1 \
--plugin 3d \
--tp 2 \
--pp 1 \
--zero_stage 2 \
--batch_size 4 \
--zero_stage 0 \
--batch_size 2 \
--max_epochs 3 \
--accumulation_steps 4 \
--accumulation_steps 1 \
--lr 5e-5 \
--max_len 400 \
--grad_checkpoint \

View File

@@ -4,5 +4,6 @@
"stop_ids": [
29871,
2
]
}
],
"end_of_assistant": "</s>"
}
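
This hunk gives the template an explicit end_of_assistant marker alongside its stop ids. A small utility like the sketch below (the function name and paths are made up; any Hugging Face tokenizer works) makes it easy to check that the stop ids in such a config actually decode to the tokens expected for the paired tokenizer:

    import json

    def describe_template_config(tokenizer, path: str) -> None:
        """Print what each stop id decodes to, so a mismatched tokenizer/config
        pair is easy to spot (sketch)."""
        with open(path, encoding="utf8") as f:
            config = json.load(f)
        for stop_id in config.get("stop_ids", []):
            print(stop_id, "->", repr(tokenizer.convert_ids_to_tokens(stop_id)))
        print("end_of_assistant:", repr(config.get("end_of_assistant")))

    # Hypothetical usage:
    #   tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
    #   describe_template_config(tok, "config/conversation_template/llama2.json")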

View File

@@ -9,14 +9,13 @@ model_data_mapping = {
'THUDM/chatglm2-6b': 'THUDM_chatglm2-6b.json',
'THUDM/chatglm3-6b': 'THUDM_chatglm3-6b.json',
'baichuan-inc/Baichuan2-13B-Chat': 'baichuan-inc_Baichuan2-13B-Chat.json',
'Qwen/Qwen-7B-Chat': 'Qwen_Qwen-7B-Chat.json',
'01-ai/Yi-1.5-9B-Chat': '01-ai_Yi-1.5-9B-Chat.json',
'01-ai/Yi-34B': '01-ai_Yi-34B.json',
'deepseek-ai/DeepSeek-V2-Lite': 'deepseek-ai_DeepSeek-V2-Lite.json',
'microsoft/phi-2': 'microsoft_phi-2.json',
'mistralai/Mixtral-8x7B-Instruct-v0.1': 'mistralai_Mixtral-8x7B-Instruct-v0.1.json'
}
chat_template_config_path = '../config/conversation_template'
chat_template_config_path = './config/conversation_template'
def test_tokenization_sft():
@@ -34,5 +33,5 @@ def test_tokenization_sft():
)
output = supervised_tokenize_sft({"messages": messages}, tokenizer, conversation_template)
with open(f"./test_data/chat_template/{model_data_mapping[model]}", "r", encoding="utf8") as f:
with open(f"./tests/test_data/chat_template/{model_data_mapping[model]}", "r", encoding="utf8") as f:
assert json.dumps(json.load(f)) == json.dumps(output), f"model: {model} failed"

View File

@@ -6,35 +6,59 @@ TEST_DATA_DIR=$BASE_DIR/tests/test_data
DATA_SAVE_PATH=$BASE_TEMP_DIR/tests
CONFIG_DIR=$BASE_DIR/config
MODELS=("colossal-llama2" "llama2" "zephyr" "mistral" "chatGLM2" "Qwen" "Vicuna" "Yi")
MODELS=("colossal-llama2" "llama2" "mistral" "chatGLM2" "chatGLM3" "deepseek" "Yi" "baichuan")
#
get_pretrain() {
local model=$1
if [[ $model == "colossal-llama2" ]]; then
echo "hpcai-tech/Colossal-LLaMA-2-7b-base"
elif [[ $model == "llama2" ]]; then
echo "hf-internal-testing/llama-tokenizer"
elif [[ $model == "zephyr" ]]; then
echo "HuggingFaceH4/zephyr-7b-beta"
elif [[ $model == "phi" ]]; then
echo "microsoft/phi-2"
elif [[ $model == "mistral" ]]; then
echo "mistralai/Mistral-7B-Instruct-v0.2"
echo "mistralai/Mistral-7B-Instruct-v0.3"
elif [[ $model == "chatGLM2" ]]; then
echo "THUDM/chatglm2-6b"
elif [[ $model == "Qwen" ]]; then
echo "Qwen/Qwen-7B-Chat"
elif [[ $model == "Vicuna" ]]; then
echo "lmsys/vicuna-7b-v1.5"
elif [[ $model == "chatGLM3" ]]; then
echo "THUDM/chatglm3-6b"
elif [[ $model == "deepseek" ]]; then
echo "deepseek-ai/DeepSeek-V2-Lite"
elif [[ $model == "Yi" ]]; then
echo "01-ai/Yi-6B-Chat"
echo "01-ai/Yi-1.5-9B-Chat"
elif [[ $model == "baichuan" ]]; then
echo "baichuan-inc/Baichuan2-13B-Chat"
else
echo "Unknown model $model"
exit 1
fi
}
get_conversation_template_config() {
local model=$1
echo "$CONFIG_DIR/conversation_template/$model.json"
if [[ $model == "colossal-llama2" ]]; then
echo "$CONFIG_DIR/conversation_template/colossal-llama2.json"
elif [[ $model == "llama2" ]]; then
echo "$CONFIG_DIR/conversation_template/llama2.json"
elif [[ $model == "deepseek" ]]; then
echo "$CONFIG_DIR/conversation_template/deepseek-ai_DeepSeek-V2-Lite.json"
elif [[ $model == "mistral" ]]; then
echo "$CONFIG_DIR/conversation_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
elif [[ $model == "chatGLM2" ]]; then
echo "$CONFIG_DIR/conversation_template/THUDM_chatglm2-6b.json"
elif [[ $model == "chatGLM3" ]]; then
echo "$CONFIG_DIR/conversation_template/THUDM_chatglm3-6b.json"
elif [[ $model == "phi" ]]; then
echo "$CONFIG_DIR/conversation_template/microsoft_phi-2.json"
elif [[ $model == "Yi" ]]; then
echo "$CONFIG_DIR/conversation_template/01-ai_Yi-1.5-9B-Chat.json"
elif [[ $model == "baichuan" ]]; then
echo "$CONFIG_DIR/conversation_template/baichuan-inc_Baichuan2-13B-Chat.json"
else
echo "Unknown model $model"
exit 1
fi
}
# Test SFT data Preparation

View File

@@ -30,7 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
MODELS_DIR=$TEMP_DIR/models_config
# Skip those tests due to CI tests timeout
MODELS=('llama')
PLUGINS=('gemini' 'gemini_auto' 'zero2' 'zero2_cpu' '3d')
PLUGINS=('3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu')
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
export OMP_NUM_THREADS=8