mirror of https://github.com/hpcaitech/ColossalAI
Upgrade PPO, DPO and RM training scripts
parent 7a7e86987d
commit 929e1e3da4
@@ -247,7 +247,7 @@ def apply_rlhf_data_format(
     target_turn = int(len(template.messages) / 2)
     prompt = template.get_prompt(target_turn * 2)
     chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt,
-                                                              tempalte.end_of_assistant)
+                                                              template.end_of_assistant)
     tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
     loss_mask = [0] * len(tokenized)
     mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id
@@ -32,3 +32,9 @@ class Critic(BaseModel):
         )
         values = self.value_head(sequence_hidden_states).squeeze(-1)  # ensure shape is (B, sequence length)
         return values
+
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def get_output_embeddings(self):
+        return self.model.get_output_embeddings()
@@ -36,3 +36,9 @@ class RewardModel(BaseModel):
         )
         values = self.value_head(sequence_hidden_states).squeeze(-1)  # Ensure shape is (B,)
         return values
+
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def get_output_embeddings(self):
+        return self.model.get_output_embeddings()
@@ -1,8 +0,0 @@
-{
-    "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\\n\\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don\\'t know the answer to a question, please don\\'t share false information.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
-    "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
-    "stop_ids": [
-        2
-    ],
-    "end_of_assistant": "<|endoftext|>"
-}
@@ -0,0 +1,9 @@
+{
+    "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+    "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+    "stop_ids": [
+        151645,
+        151643
+    ],
+    "end_of_assistant": "<|im_end|>"
+}
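For reference, a ChatML-style template like the one added above can be exercised directly through the Hugging Face tokenizer API. The snippet below is a minimal sketch, not part of this commit: the checkpoint name and the example messages are illustrative assumptions for a model whose tokenizer ships an equivalent chat template (the stop ids above match Qwen-style tokenizers).

# Minimal sketch: render a ChatML-style prompt with transformers.
# The checkpoint below is an assumption used only for illustration.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
messages = [
    {"role": "system", "content": "A chat between a curious human and an artificial intelligence assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]
# add_generation_prompt=True appends '<|im_start|>assistant\n', mirroring the template above
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)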
@@ -1,4 +1,4 @@
-SAVE_DIR="/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf"
+SAVE_DIR="/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B"
 
 rm -rf $SAVE_DIR/cache
 rm -rf $SAVE_DIR/jsonl
@@ -14,9 +14,9 @@ rm -rf $SAVE_DIR/arrow
 
 
 python prepare_dataset.py --type sft \
-    --data_input_dirs /mnt/jfs-hdd/home/yeanbang/data/dataset/sft_data/alpaca/data_preprocessed/train \
-    --conversation_template_config /home/yeanbang/data/ColossalAI/applications/ColossalChat/config/conversation_template/llama2.json \
-    --tokenizer_dir "/mnt/jfs-hdd/share/models/Llama-2-7b-chat-hf" \
+    --data_input_dirs /home/yeanbang/data/experiment/dataset/sft_data/test/sft-data \
+    --conversation_template_config /home/yeanbang/data/ColossalAI/applications/ColossalChat/config/conversation_template/01-ai_Yi-1.5-9B-Chat.json \
+    --tokenizer_dir "/mnt/jfs-hdd/share/models/Yi-1.5-6B" \
     --data_cache_dir $SAVE_DIR/cache \
     --data_jsonl_output_dir $SAVE_DIR/jsonl \
     --data_arrow_output_dir $SAVE_DIR/arrow \
@@ -56,6 +56,7 @@ def train(args):
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=True,
+            enable_flash_attention=args.use_flash_attn
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -63,6 +64,7 @@ def train(args):
             placement_policy="auto",
             initial_scale=2**16,
             max_norm=args.grad_clip,
+            enable_flash_attention=args.use_flash_attn
         )
     elif args.plugin == "zero2":
         plugin = LowLevelZeroPlugin(
@@ -82,9 +84,15 @@ def train(args):
     elif args.plugin == "3d":
         plugin = HybridParallelPlugin(
             tp_size=args.tp,
-            pp_size=1,
-            zero_stage=0,
+            pp_size=args.pp,
+            sp_size=args.sp,
+            sequence_parallelism_mode=args.sp_mode,
+            zero_stage=args.zero_stage,
+            enable_flash_attention=args.use_flash_attn,
+            enable_sequence_parallelism=True if args.sp > 1 else False,
+            cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
+            parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
         )
     else:
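Taken together, the new --pp, --sp, --sp_mode, --zero_stage and --zero_cpu_offload flags let the "3d" plugin be configured from the command line. Below is a minimal standalone sketch of that wiring, not part of the commit: it assumes the process was launched under a distributed context (e.g. via `colossalai run` or torchrun), the argument names mirror the diff above, and the defaults and precision value are illustrative assumptions.

# Minimal sketch (assumptions noted above): mapping the new CLI flags onto HybridParallelPlugin.
import argparse

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

parser = argparse.ArgumentParser()
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--zero_stage", type=int, default=0, choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
args = parser.parse_args()

colossalai.launch_from_torch()  # requires a torchrun/colossalai-run launch; older versions take config={}

plugin = HybridParallelPlugin(
    tp_size=args.tp,
    pp_size=args.pp,
    sp_size=args.sp,
    sequence_parallelism_mode=args.sp_mode,
    zero_stage=args.zero_stage,
    enable_sequence_parallelism=args.sp > 1,
    cpu_offload=args.zero_stage >= 1 and args.zero_cpu_offload,
    precision="bf16",
)
booster = Booster(plugin=plugin)  # the training scripts then wrap model/optimizer via booster.boost(...)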
@@ -172,7 +180,7 @@ def train(args):
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
-        use_tp=args.tp > 1,
+        tp_size=args.tp,
     )
 
     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
@@ -290,6 +298,11 @@ if __name__ == "__main__":
     parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
     parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
     parser.add_argument("--tp", type=int, default=1)
+    parser.add_argument("--pp", type=int, default=1)
+    parser.add_argument("--sp", type=int, default=1)
+    parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
+    parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
+    parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
     parser.add_argument("--pretrain", type=str, default=None)
     parser.add_argument("--model_type", type=str, default=None)
     parser.add_argument("--tokenizer_dir", type=str, default=None)
@@ -18,6 +18,7 @@ from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dr
 from coati.trainer import PPOTrainer
 from coati.utils import load_checkpoint
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from colossalai.shardformer.policies.auto_policy import get_autopolicy
 
 import colossalai
 from colossalai.booster import Booster
@@ -86,32 +87,6 @@ def train(args):
     disable_dropout(actor)
     disable_dropout(critic)
 
-    if args.tp > 1:
-        if reward_model.model.config.architectures[0] != critic.model.config.architectures[0]:
-            raise ValueError("Reward model and critic model must have the same architecture")
-        if reward_model.model.config.architectures[0] == "BloomForCausalLM":
-            from colossalai.shardformer.policies.bloom import BloomPolicy
-
-            booster_policy = BloomPolicy()
-        elif reward_model.model.config.architectures[0] == "LlamaForCausalLM":
-            from colossalai.shardformer.policies.llama import LlamaPolicy
-
-            booster_policy = LlamaPolicy()
-        elif reward_model.model.config.architectures[0] == "GPT2LMHeadModel":
-            from colossalai.shardformer.policies.gpt2 import GPT2Policy
-
-            booster_policy = GPT2Policy()
-        elif reward_model.model.config.architectures[0] == "ChatGLMModel":
-            from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
-
-            booster_policy = ChatGLMPolicy()
-        elif reward_model.model.config.architectures[0] == "OPTForCausalLM":
-            from colossalai.shardformer.policies.opt import OPTPolicy
-
-            booster_policy = OPTPolicy()
-        else:
-            raise ValueError("Unknown model architecture for policy")
-
     if args.lora_rank > 0:
         actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias)
         critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias)
@@ -186,7 +161,7 @@ def train(args):
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
-        use_tp=args.tp > 1,
+        tp_size=args.tp,
     )
 
     if len(args.ptx_dataset) > 0:
@@ -198,7 +173,7 @@ def train(args):
             shuffle=True,
             drop_last=True,
             collate_fn=data_collator,
-            use_tp=args.tp > 1,
+            tp_size=args.tp,
         )
     else:
         train_pretrain_dataloader = None
@@ -237,6 +212,7 @@ def train(args):
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=True,
+            enable_flash_attention=args.use_flash_attn
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -244,6 +220,7 @@ def train(args):
             placement_policy="auto",
             initial_scale=2**16,
             max_norm=args.grad_clip,
+            enable_flash_attention=args.use_flash_attn
         )
     elif args.plugin == "zero2":
         plugin = LowLevelZeroPlugin(
@@ -270,11 +247,17 @@ def train(args):
         )
         custom_plugin = HybridParallelPlugin(
             tp_size=args.tp,
-            pp_size=1,
-            zero_stage=0,
+            pp_size=args.pp,
+            sp_size=args.sp,
+            sequence_parallelism_mode=args.sp_mode,
+            zero_stage=args.zero_stage,
+            enable_flash_attention=args.use_flash_attn,
+            enable_sequence_parallelism=True if args.sp > 1 else False,
+            cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
+            parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
-            custom_policy=booster_policy,
+            custom_policy=get_autopolicy(reward_model.model),
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
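With this change, the hand-written architecture-to-policy table removed earlier in the file is no longer needed: `custom_policy` is resolved from the wrapped model by shardformer's automatic lookup. A minimal sketch of that lookup in isolation follows; it is not part of the commit, and the GPT-2 checkpoint is an assumption used purely for illustration.

# Minimal sketch: resolve a shardformer policy automatically instead of mapping architectures by hand.
from transformers import AutoModelForCausalLM

from colossalai.shardformer.policies.auto_policy import get_autopolicy

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint, not from this commit
policy = get_autopolicy(model)  # raises NotImplementedError when no policy is registered for the model
print(type(policy).__name__)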
@@ -474,6 +457,11 @@ if __name__ == "__main__":
     parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
     parser.add_argument("--tokenizer_dir", type=str, default=None)
     parser.add_argument("--tp", type=int, default=1)
+    parser.add_argument("--pp", type=int, default=1)
+    parser.add_argument("--sp", type=int, default=1)
+    parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
+    parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
+    parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
     parser.add_argument("--pretrain", type=str, default=None)
     parser.add_argument("--rm_pretrain", type=str, default=None)
     parser.add_argument("--checkpoint_path", type=str, default=None)
@@ -15,7 +15,8 @@ from coati.dataset import (
 from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module
 from coati.trainer import RewardModelTrainer
 from coati.utils import load_checkpoint
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoConfig
+from colossalai.shardformer.policies.auto_policy import get_autopolicy
 
 import colossalai
 from colossalai.booster import Booster
@@ -56,31 +57,10 @@ def train(args):
         )
         coordinator.print_on_master(msg="Flash-attention enabled successfully")
     else:
-        model = RewardModel(args.pretrain)
-
-    if args.tp > 1:
-        if model.model.config.architectures[0] == "BloomForCausalLM":
-            from colossalai.shardformer.policies.bloom import BloomPolicy
-
-            booster_policy = BloomPolicy()
-        elif model.model.config.architectures[0] == "LlamaForCausalLM":
-            from colossalai.shardformer.policies.llama import LlamaPolicy
-
-            booster_policy = LlamaPolicy()
-        elif model.model.config.architectures[0] == "GPT2LMHeadModel":
-            from colossalai.shardformer.policies.gpt2 import GPT2Policy
-
-            booster_policy = GPT2Policy()
-        elif model.model.config.architectures[0] == "ChatGLMModel":
-            from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
-
-            booster_policy = ChatGLMPolicy()
-        elif model.model.config.architectures[0] == "OPTForCausalLM":
-            from colossalai.shardformer.policies.opt import OPTPolicy
-
-            booster_policy = OPTPolicy()
-        else:
-            raise ValueError("Unknown model architecture for policy")
+        model_config = AutoConfig.from_pretrained(args.pretrain)
+        model = RewardModel(
+            args.pretrain,
+        )
 
     if args.lora_rank > 0:
         model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
@@ -100,6 +80,7 @@ def train(args):
             placement_policy="static",
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_flash_attention=args.use_flash_attn,
+            enable_gradient_accumulation=True,
         )
     elif args.plugin == "gemini_auto":
@@ -107,6 +88,7 @@ def train(args):
             precision=args.mixed_precision,
             placement_policy="auto",
             initial_scale=2**16,
+            enable_flash_attention=args.use_flash_attn,
             max_norm=args.grad_clip,
         )
     elif args.plugin == "zero2":
@@ -127,11 +109,17 @@ def train(args):
     elif args.plugin == "3d":
         plugin = HybridParallelPlugin(
             tp_size=args.tp,
-            pp_size=1,
-            zero_stage=0,
+            pp_size=args.pp,
+            sp_size=args.sp,
+            sequence_parallelism_mode=args.sp_mode,
+            zero_stage=args.zero_stage,
+            enable_flash_attention=args.use_flash_attn,
+            enable_sequence_parallelism=True if args.sp > 1 else False,
+            cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
+            parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
-            custom_policy=booster_policy,
+            custom_policy=get_autopolicy(model.model)
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
@@ -189,7 +177,7 @@ def train(args):
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
-        use_tp=args.tp > 1,
+        tp_size=args.tp,
     )
 
     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
@@ -307,6 +295,11 @@ if __name__ == "__main__":
     parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
     parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
    parser.add_argument("--tp", type=int, default=1)
+    parser.add_argument("--pp", type=int, default=1)
+    parser.add_argument("--sp", type=int, default=1)
+    parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
+    parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
+    parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
     parser.add_argument("--pretrain", type=str, default=None)
     parser.add_argument("--tokenizer_dir", type=str, default=None)
     parser.add_argument("--dataset", nargs="+", default=[])
@@ -48,29 +48,29 @@ def train(args):
         torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
         trust_remote_code=True)
     # check if the hybrid parallel plugin is compatible with the model
-    # try:
-    #     from colossalai.shardformer.policies.auto_policy import get_autopolicy
-    #     policy = get_autopolicy(model)
-    #     if policy is not None:
-    #         if args.plugin in ['zero2', 'zero2_cpu']:
-    #             # if compatible, set the plugin to hybrid, which use colo-attention
-    #             args.plugin = 'hybrid'
-    #             args.zero_stage = 2
-    #             if args.plugin == 'zero2_cpu':
-    #                 args.zero_cpu_offload = True
-    #             else:
-    #                 args.zero_cpu_offload = False
-    #     logger.info(f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}")
-    # except NotImplementedError:
-    #     logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
-    #     if args.use_flash_attn:
-    #         del model
-    #         model = AutoModelForCausalLM.from_pretrained(
-    #             args.pretrain,
-    #             torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-    #             attn_implementation="flash_attention_2",
-    #             trust_remote_code=True
-    #         )
+    try:
+        from colossalai.shardformer.policies.auto_policy import get_autopolicy
+        policy = get_autopolicy(model)
+        if policy is not None:
+            if args.plugin in ['zero2', 'zero2_cpu']:
+                # if compatible, set the plugin to hybrid, which use colo-attention
+                args.plugin = '3d'
+                args.zero_stage = 2
+                if args.plugin == 'zero2_cpu':
+                    args.zero_cpu_offload = True
+                else:
+                    args.zero_cpu_offload = False
+        logger.info(f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}")
+    except NotImplementedError:
+        logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
+        if args.use_flash_attn:
+            del model
+            model = AutoModelForCausalLM.from_pretrained(
+                args.pretrain,
+                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+                attn_implementation="flash_attention_2",
+                trust_remote_code=True
+            )
     if args.lora_rank > 0:
         model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
 
@@ -112,7 +112,7 @@ def train(args):
             cpu_offload=True,
             max_norm=args.grad_clip,
         )
-    elif args.plugin == "hybrid":
+    elif args.plugin == "3d":
         plugin = HybridParallelPlugin(
             tp_size=args.tp,
             pp_size=args.pp,
@@ -224,7 +224,6 @@ def train(args):
         lr_scheduler=lr_scheduler,
         dataloader=train_dataloader,
     )
-    # model = model.to(get_current_device())
     torch.set_default_dtype(torch.float)
 
     coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
@@ -309,7 +308,7 @@ if __name__ == "__main__":
         "--plugin",
         type=str,
         default="gemini",
-        choices=["gemini", "gemini_auto", "hybrid", "ddp", "zero2_cpu", "zero2"],
+        choices=["gemini", "gemini_auto", "3d", "ddp", "zero2_cpu", "zero2"],
         help="Choose which plugin to use",
     )
     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
@@ -15,24 +15,24 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
 
 
 # export CUDA_VISIBLE_DEVICES=4,5,6
-set_n_least_used_CUDA_VISIBLE_DEVICES 4
+set_n_least_used_CUDA_VISIBLE_DEVICES 2
 PROJECT_NAME="sft"
 PARENT_SAVE_DIR="/home/yeanbang/data/experiment/output/model" # Path to a folder to save checkpoints
 PARENT_TENSORBOARD_DIR="/home/yeanbang/data/experiment/logs/tensorboard" # Path to a folder to save logs
 PARENT_CONFIG_FILE="/home/yeanbang/data/experiment/logs/config" # Path to a folder to save training config logs
-PRETRAINED_MODEL_PATH="/mnt/jfs-hdd/share/models/Llama-2-7b-chat-hf" # huggingface or local model path
-PRETRAINED_TOKENIZER_PATH="/mnt/jfs-hdd/share/models/Llama-2-7b-chat-hf" # huggingface or local tokenizer path
+PRETRAINED_MODEL_PATH="/mnt/jfs-hdd/share/models/Yi-1.5-6B" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="/mnt/jfs-hdd/share/models/Yi-1.5-6B" # huggingface or local tokenizer path
 declare -a dataset=(
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00000
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00001
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00002
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00003
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00004
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00005
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00006
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00007
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00008
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00009
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00000
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00001
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00002
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00003
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00004
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00005
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00006
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00007
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00008
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00009
 )
 
 TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
@@ -43,7 +43,7 @@ CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
 echo $(which colossalai)
 echo $(which python)
 # the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size
-colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile train_sft.py \
+colossalai run --nproc_per_node 2 --master_port 31312 --hostfile ./hostfile train_sft.py \
     --pretrain $PRETRAINED_MODEL_PATH \
     --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
     --save_interval 4000 \
@@ -51,13 +51,13 @@ colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile trai
     --save_path $SAVE_DIR \
     --config_file $CONFIG_FILE \
     --lora_rank 0 \
-    --plugin zero2 \
-    --tp 1 \
+    --plugin 3d \
+    --tp 2 \
     --pp 1 \
-    --zero_stage 2 \
-    --batch_size 4 \
+    --zero_stage 0 \
+    --batch_size 2 \
     --max_epochs 3 \
-    --accumulation_steps 4 \
+    --accumulation_steps 1 \
     --lr 5e-5 \
     --max_len 400 \
     --grad_checkpoint \
@@ -4,5 +4,6 @@
     "stop_ids": [
         29871,
         2
-    ]
+    ],
+    "end_of_assistant": "</s>"
 }
@@ -9,14 +9,13 @@ model_data_mapping = {
     'THUDM/chatglm2-6b': 'THUDM_chatglm2-6b.json',
     'THUDM/chatglm3-6b': 'THUDM_chatglm3-6b.json',
     'baichuan-inc/Baichuan2-13B-Chat': 'baichuan-inc_Baichuan2-13B-Chat.json',
-    'Qwen/Qwen-7B-Chat': 'Qwen_Qwen-7B-Chat.json',
     '01-ai/Yi-1.5-9B-Chat': '01-ai_Yi-1.5-9B-Chat.json',
     '01-ai/Yi-34B': '01-ai_Yi-34B.json',
     'deepseek-ai/DeepSeek-V2-Lite': 'deepseek-ai_DeepSeek-V2-Lite.json',
     'microsoft/phi-2': 'microsoft_phi-2.json',
     'mistralai/Mixtral-8x7B-Instruct-v0.1': 'mistralai_Mixtral-8x7B-Instruct-v0.1.json'
 }
-chat_template_config_path = '../config/conversation_template'
+chat_template_config_path = './config/conversation_template'
 
 
 def test_tokenization_sft():
@@ -34,5 +33,5 @@ def test_tokenization_sft():
     )
 
     output = supervised_tokenize_sft({"messages": messages}, tokenizer, conversation_template)
-    with open(f"./test_data/chat_template/{model_data_mapping[model]}", "r", encoding="utf8") as f:
+    with open(f"./tests/test_data/chat_template/{model_data_mapping[model]}", "r", encoding="utf8") as f:
         assert json.dumps(json.load(f)) == json.dumps(output), f"model: {model} failed"
@@ -6,35 +6,59 @@ TEST_DATA_DIR=$BASE_DIR/tests/test_data
 DATA_SAVE_PATH=$BASE_TEMP_DIR/tests
 CONFIG_DIR=$BASE_DIR/config
 
-MODELS=("colossal-llama2" "llama2" "zephyr" "mistral" "chatGLM2" "Qwen" "Vicuna" "Yi")
+MODELS=("colossal-llama2" "llama2" "mistral" "chatGLM2" "chatGLM3" "deepseek" "Yi" "baichuan")
+#
 
 get_pretrain() {
     local model=$1
     if [[ $model == "colossal-llama2" ]]; then
         echo "hpcai-tech/Colossal-LLaMA-2-7b-base"
     elif [[ $model == "llama2" ]]; then
         echo "hf-internal-testing/llama-tokenizer"
-    elif [[ $model == "zephyr" ]]; then
-        echo "HuggingFaceH4/zephyr-7b-beta"
     elif [[ $model == "phi" ]]; then
         echo "microsoft/phi-2"
     elif [[ $model == "mistral" ]]; then
-        echo "mistralai/Mistral-7B-Instruct-v0.2"
+        echo "mistralai/Mistral-7B-Instruct-v0.3"
     elif [[ $model == "chatGLM2" ]]; then
         echo "THUDM/chatglm2-6b"
-    elif [[ $model == "Qwen" ]]; then
-        echo "Qwen/Qwen-7B-Chat"
-    elif [[ $model == "Vicuna" ]]; then
-        echo "lmsys/vicuna-7b-v1.5"
+    elif [[ $model == "chatGLM3" ]]; then
+        echo "THUDM/chatglm3-6b"
+    elif [[ $model == "deepseek" ]]; then
+        echo "deepseek-ai/DeepSeek-V2-Lite"
     elif [[ $model == "Yi" ]]; then
-        echo "01-ai/Yi-6B-Chat"
+        echo "01-ai/Yi-1.5-9B-Chat"
+    elif [[ $model == "baichuan" ]]; then
+        echo "baichuan-inc/Baichuan2-13B-Chat"
     else
         echo "Unknown model $model"
         exit 1
     fi
 }
 
 
 get_conversation_template_config() {
     local model=$1
-    echo "$CONFIG_DIR/conversation_template/$model.json"
+    if [[ $model == "colossal-llama2" ]]; then
+        echo "$CONFIG_DIR/conversation_template/colossal-llama2.json"
+    elif [[ $model == "llama2" ]]; then
+        echo "$CONFIG_DIR/conversation_template/llama2.json"
+    elif [[ $model == "deepseek" ]]; then
+        echo "$CONFIG_DIR/conversation_template/deepseek-ai_DeepSeek-V2-Lite.json"
+    elif [[ $model == "mistral" ]]; then
+        echo "$CONFIG_DIR/conversation_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
+    elif [[ $model == "chatGLM2" ]]; then
+        echo "$CONFIG_DIR/conversation_template/THUDM_chatglm2-6b.json"
+    elif [[ $model == "chatGLM3" ]]; then
+        echo "$CONFIG_DIR/conversation_template/THUDM_chatglm3-6b.json"
+    elif [[ $model == "phi" ]]; then
+        echo "$CONFIG_DIR/conversation_template/microsoft_phi-2.json"
+    elif [[ $model == "Yi" ]]; then
+        echo "$CONFIG_DIR/conversation_template/01-ai_Yi-1.5-9B-Chat.json"
+    elif [[ $model == "baichuan" ]]; then
+        echo "$CONFIG_DIR/conversation_template/baichuan-inc_Baichuan2-13B-Chat.json"
+    else
+        echo "Unknown model $model"
+        exit 1
+    fi
 }
 
 # Test SFT data Preparation
@@ -30,7 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
 MODELS_DIR=$TEMP_DIR/models_config
 # Skip those tests due to CI tests timeout
 MODELS=('llama')
-PLUGINS=('gemini' 'gemini_auto' 'zero2' 'zero2_cpu' '3d')
+PLUGINS=('3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu')
 LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
 
 export OMP_NUM_THREADS=8