Upgrade PPO, DPO, and RM scripts

pull/5759/head
YeAnbang 2024-05-28 03:04:39 +00:00
parent 7a7e86987d
commit 929e1e3da4
15 changed files with 169 additions and 139 deletions

View File

@@ -247,7 +247,7 @@ def apply_rlhf_data_format(
     target_turn = int(len(template.messages) / 2)
     prompt = template.get_prompt(target_turn * 2)
     chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt,
-        tempalte.end_of_assistant)
+        template.end_of_assistant)
     tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
     loss_mask = [0] * len(tokenized)
     mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id

View File

@@ -32,3 +32,9 @@ class Critic(BaseModel):
         )
         values = self.value_head(sequence_hidden_states).squeeze(-1)  # ensure shape is (B, sequence length)
         return values
+
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def get_output_embeddings(self):
+        return self.model.get_output_embeddings()

View File

@@ -36,3 +36,9 @@ class RewardModel(BaseModel):
         )
         values = self.value_head(sequence_hidden_states).squeeze(-1)  # Ensure shape is (B,)
         return values
+
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def get_output_embeddings(self):
+        return self.model.get_output_embeddings()
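Note (editor's addition, not part of the commit): the delegating accessors added to Critic and RewardModel keep the wrappers compatible with utilities that expect a Hugging Face-style interface, for example code that inspects or resizes the token embeddings of whatever module it is handed. A minimal sketch of the pattern, with illustrative names only:

import torch.nn as nn

class EmbeddingAwareWrapper(nn.Module):
    """Illustrative wrapper mirroring the Critic/RewardModel changes above."""

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.model = backbone  # wrapped transformer backbone

    def get_input_embeddings(self):
        # forward the call so callers reach the backbone's embedding table
        return self.model.get_input_embeddings()

    def get_output_embeddings(self):
        return self.model.get_output_embeddings()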

View File

@@ -1,8 +0,0 @@
-{
-    "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\\n\\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don\\'t know the answer to a question, please don\\'t share false information.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
-    "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
-    "stop_ids": [
-        2
-    ],
-    "end_of_assistant": "<|endoftext|>"
-}

View File

@@ -0,0 +1,9 @@
+{
+    "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+    "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+    "stop_ids": [
+        151645,
+        151643
+    ],
+    "end_of_assistant": "<|im_end|>"
+}
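Note (editor's addition): this is a ChatML-style template, and the stop ids 151645 and 151643 correspond to <|im_end|> and <|endoftext|> in Qwen tokenizers, which this config appears to target. A minimal rendering sketch (assumes jinja2 is installed; the messages are made up):

from jinja2 import Template

# Template string copied from the config above.
chat_template = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# Render the prompt exactly as the data-preparation pipeline would see it.
print(Template(chat_template).render(messages=messages, add_generation_prompt=True))
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hello!<|im_end|>
# <|im_start|>assistant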

View File

@@ -1,4 +1,4 @@
-SAVE_DIR="/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf"
+SAVE_DIR="/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B"
 rm -rf $SAVE_DIR/cache
 rm -rf $SAVE_DIR/jsonl
@@ -14,9 +14,9 @@ rm -rf $SAVE_DIR/arrow
 python prepare_dataset.py --type sft \
-    --data_input_dirs /mnt/jfs-hdd/home/yeanbang/data/dataset/sft_data/alpaca/data_preprocessed/train \
-    --conversation_template_config /home/yeanbang/data/ColossalAI/applications/ColossalChat/config/conversation_template/llama2.json \
-    --tokenizer_dir "/mnt/jfs-hdd/share/models/Llama-2-7b-chat-hf" \
+    --data_input_dirs /home/yeanbang/data/experiment/dataset/sft_data/test/sft-data \
+    --conversation_template_config /home/yeanbang/data/ColossalAI/applications/ColossalChat/config/conversation_template/01-ai_Yi-1.5-9B-Chat.json \
+    --tokenizer_dir "/mnt/jfs-hdd/share/models/Yi-1.5-6B" \
     --data_cache_dir $SAVE_DIR/cache \
     --data_jsonl_output_dir $SAVE_DIR/jsonl \
     --data_arrow_output_dir $SAVE_DIR/arrow \

View File

@@ -56,6 +56,7 @@ def train(args):
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=True,
+            enable_flash_attention=args.use_flash_attn
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -63,6 +64,7 @@ def train(args):
             placement_policy="auto",
             initial_scale=2**16,
             max_norm=args.grad_clip,
+            enable_flash_attention=args.use_flash_attn
         )
     elif args.plugin == "zero2":
         plugin = LowLevelZeroPlugin(
@@ -82,9 +84,15 @@ def train(args):
     elif args.plugin == "3d":
         plugin = HybridParallelPlugin(
             tp_size=args.tp,
-            pp_size=1,
-            zero_stage=0,
+            pp_size=args.pp,
+            sp_size=args.sp,
+            sequence_parallelism_mode=args.sp_mode,
+            zero_stage=args.zero_stage,
+            enable_flash_attention=args.use_flash_attn,
+            enable_sequence_parallelism=True if args.sp > 1 else False,
+            cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
             parallel_output=False,
+            max_norm=args.grad_clip,
             precision=args.mixed_precision,
         )
     else:
@@ -172,7 +180,7 @@ def train(args):
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
-        use_tp=args.tp > 1,
+        tp_size=args.tp,
     )
     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
@@ -290,6 +298,11 @@ if __name__ == "__main__":
     parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
     parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
     parser.add_argument("--tp", type=int, default=1)
+    parser.add_argument("--pp", type=int, default=1)
+    parser.add_argument("--sp", type=int, default=1)
+    parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
+    parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
+    parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
     parser.add_argument("--pretrain", type=str, default=None)
     parser.add_argument("--model_type", type=str, default=None)
     parser.add_argument("--tokenizer_dir", type=str, default=None)

View File

@@ -18,6 +18,7 @@ from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dr
 from coati.trainer import PPOTrainer
 from coati.utils import load_checkpoint
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from colossalai.shardformer.policies.auto_policy import get_autopolicy
 import colossalai
 from colossalai.booster import Booster
@@ -86,32 +87,6 @@ def train(args):
     disable_dropout(actor)
     disable_dropout(critic)
-    if args.tp > 1:
-        if reward_model.model.config.architectures[0] != critic.model.config.architectures[0]:
-            raise ValueError("Reward model and critic model must have the same architecture")
-        if reward_model.model.config.architectures[0] == "BloomForCausalLM":
-            from colossalai.shardformer.policies.bloom import BloomPolicy
-            booster_policy = BloomPolicy()
-        elif reward_model.model.config.architectures[0] == "LlamaForCausalLM":
-            from colossalai.shardformer.policies.llama import LlamaPolicy
-            booster_policy = LlamaPolicy()
-        elif reward_model.model.config.architectures[0] == "GPT2LMHeadModel":
-            from colossalai.shardformer.policies.gpt2 import GPT2Policy
-            booster_policy = GPT2Policy()
-        elif reward_model.model.config.architectures[0] == "ChatGLMModel":
-            from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
-            booster_policy = ChatGLMPolicy()
-        elif reward_model.model.config.architectures[0] == "OPTForCausalLM":
-            from colossalai.shardformer.policies.opt import OPTPolicy
-            booster_policy = OPTPolicy()
-        else:
-            raise ValueError("Unknown model architecture for policy")
     if args.lora_rank > 0:
         actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias)
         critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias)
@@ -186,7 +161,7 @@ def train(args):
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
-        use_tp=args.tp > 1,
+        tp_size=args.tp,
     )
     if len(args.ptx_dataset) > 0:
@@ -198,7 +173,7 @@ def train(args):
             shuffle=True,
             drop_last=True,
             collate_fn=data_collator,
-            use_tp=args.tp > 1,
+            tp_size=args.tp,
         )
     else:
         train_pretrain_dataloader = None
@@ -237,6 +212,7 @@ def train(args):
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=True,
+            enable_flash_attention=args.use_flash_attn
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -244,6 +220,7 @@ def train(args):
             placement_policy="auto",
             initial_scale=2**16,
             max_norm=args.grad_clip,
+            enable_flash_attention=args.use_flash_attn
        )
     elif args.plugin == "zero2":
         plugin = LowLevelZeroPlugin(
@@ -270,11 +247,17 @@ def train(args):
         )
         custom_plugin = HybridParallelPlugin(
             tp_size=args.tp,
-            pp_size=1,
-            zero_stage=0,
+            pp_size=args.pp,
+            sp_size=args.sp,
+            sequence_parallelism_mode=args.sp_mode,
+            zero_stage=args.zero_stage,
+            enable_flash_attention=args.use_flash_attn,
+            enable_sequence_parallelism=True if args.sp > 1 else False,
+            cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
             parallel_output=False,
+            max_norm=args.grad_clip,
             precision=args.mixed_precision,
-            custom_policy=booster_policy,
+            custom_policy=get_autopolicy(reward_model.model),
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
@@ -474,6 +457,11 @@ if __name__ == "__main__":
     parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
     parser.add_argument("--tokenizer_dir", type=str, default=None)
     parser.add_argument("--tp", type=int, default=1)
+    parser.add_argument("--pp", type=int, default=1)
+    parser.add_argument("--sp", type=int, default=1)
+    parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
+    parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
+    parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
     parser.add_argument("--pretrain", type=str, default=None)
     parser.add_argument("--rm_pretrain", type=str, default=None)
     parser.add_argument("--checkpoint_path", type=str, default=None)

View File

@@ -15,7 +15,8 @@ from coati.dataset import (
 from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module
 from coati.trainer import RewardModelTrainer
 from coati.utils import load_checkpoint
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoConfig
+from colossalai.shardformer.policies.auto_policy import get_autopolicy
 import colossalai
 from colossalai.booster import Booster
@@ -56,31 +57,10 @@ def train(args):
         )
         coordinator.print_on_master(msg="Flash-attention enabled successfully")
     else:
-        model = RewardModel(args.pretrain)
-    if args.tp > 1:
-        if model.model.config.architectures[0] == "BloomForCausalLM":
-            from colossalai.shardformer.policies.bloom import BloomPolicy
-            booster_policy = BloomPolicy()
-        elif model.model.config.architectures[0] == "LlamaForCausalLM":
-            from colossalai.shardformer.policies.llama import LlamaPolicy
-            booster_policy = LlamaPolicy()
-        elif model.model.config.architectures[0] == "GPT2LMHeadModel":
-            from colossalai.shardformer.policies.gpt2 import GPT2Policy
-            booster_policy = GPT2Policy()
-        elif model.model.config.architectures[0] == "ChatGLMModel":
-            from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
-            booster_policy = ChatGLMPolicy()
-        elif model.model.config.architectures[0] == "OPTForCausalLM":
-            from colossalai.shardformer.policies.opt import OPTPolicy
-            booster_policy = OPTPolicy()
-        else:
-            raise ValueError("Unknown model architecture for policy")
+        model_config = AutoConfig.from_pretrained(args.pretrain)
+        model = RewardModel(
+            args.pretrain,
+        )
     if args.lora_rank > 0:
         model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
@@ -100,6 +80,7 @@ def train(args):
             placement_policy="static",
             initial_scale=2**16,
             max_norm=args.grad_clip,
+            enable_flash_attention=args.use_flash_attn,
             enable_gradient_accumulation=True,
         )
     elif args.plugin == "gemini_auto":
@@ -107,6 +88,7 @@ def train(args):
             precision=args.mixed_precision,
             placement_policy="auto",
             initial_scale=2**16,
+            enable_flash_attention=args.use_flash_attn,
             max_norm=args.grad_clip,
         )
     elif args.plugin == "zero2":
@@ -127,11 +109,17 @@ def train(args):
     elif args.plugin == "3d":
         plugin = HybridParallelPlugin(
             tp_size=args.tp,
-            pp_size=1,
-            zero_stage=0,
+            pp_size=args.pp,
+            sp_size=args.sp,
+            sequence_parallelism_mode=args.sp_mode,
+            zero_stage=args.zero_stage,
+            enable_flash_attention=args.use_flash_attn,
+            enable_sequence_parallelism=True if args.sp > 1 else False,
+            cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
             parallel_output=False,
+            max_norm=args.grad_clip,
             precision=args.mixed_precision,
-            custom_policy=booster_policy,
+            custom_policy=get_autopolicy(model.model)
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
@@ -189,7 +177,7 @@ def train(args):
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
-        use_tp=args.tp > 1,
+        tp_size=args.tp,
     )
     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
@@ -307,6 +295,11 @@ if __name__ == "__main__":
     parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
     parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
     parser.add_argument("--tp", type=int, default=1)
+    parser.add_argument("--pp", type=int, default=1)
+    parser.add_argument("--sp", type=int, default=1)
+    parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
+    parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
+    parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
     parser.add_argument("--pretrain", type=str, default=None)
     parser.add_argument("--tokenizer_dir", type=str, default=None)
     parser.add_argument("--dataset", nargs="+", default=[])

View File

@@ -48,29 +48,29 @@ def train(args):
         torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
         trust_remote_code=True)
     # check if the hybrid parallel plugin is compatible with the model
-    # try:
-    #     from colossalai.shardformer.policies.auto_policy import get_autopolicy
-    #     policy = get_autopolicy(model)
-    #     if policy is not None:
-    #         if args.plugin in ['zero2', 'zero2_cpu']:
-    #             # if compatible, set the plugin to hybrid, which use colo-attention
-    #             args.plugin = 'hybrid'
-    #             args.zero_stage = 2
-    #             if args.plugin == 'zero2_cpu':
-    #                 args.zero_cpu_offload = True
-    #             else:
-    #                 args.zero_cpu_offload = False
-    #         logger.info(f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}")
-    # except NotImplementedError:
-    #     logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
-    # if args.use_flash_attn:
-    #     del model
-    #     model = AutoModelForCausalLM.from_pretrained(
-    #         args.pretrain,
-    #         torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-    #         attn_implementation="flash_attention_2",
-    #         trust_remote_code=True
-    #     )
+    try:
+        from colossalai.shardformer.policies.auto_policy import get_autopolicy
+        policy = get_autopolicy(model)
+        if policy is not None:
+            if args.plugin in ['zero2', 'zero2_cpu']:
+                # if compatible, set the plugin to hybrid, which use colo-attention
+                args.plugin = '3d'
+                args.zero_stage = 2
+                if args.plugin == 'zero2_cpu':
+                    args.zero_cpu_offload = True
+                else:
+                    args.zero_cpu_offload = False
+            logger.info(f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}")
+    except NotImplementedError:
+        logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
+    if args.use_flash_attn:
+        del model
+        model = AutoModelForCausalLM.from_pretrained(
+            args.pretrain,
+            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+            attn_implementation="flash_attention_2",
+            trust_remote_code=True
+        )
     if args.lora_rank > 0:
         model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
@@ -112,7 +112,7 @@ def train(args):
             cpu_offload=True,
             max_norm=args.grad_clip,
         )
-    elif args.plugin == "hybrid":
+    elif args.plugin == "3d":
         plugin = HybridParallelPlugin(
             tp_size=args.tp,
             pp_size=args.pp,
@@ -224,7 +224,6 @@ def train(args):
         lr_scheduler=lr_scheduler,
         dataloader=train_dataloader,
     )
-    # model = model.to(get_current_device())
     torch.set_default_dtype(torch.float)
     coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
@@ -309,7 +308,7 @@ if __name__ == "__main__":
         "--plugin",
         type=str,
         default="gemini",
-        choices=["gemini", "gemini_auto", "hybrid", "ddp", "zero2_cpu", "zero2"],
+        choices=["gemini", "gemini_auto", "3d", "ddp", "zero2_cpu", "zero2"],
         help="Choose which plugin to use",
     )
     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")

View File

@@ -15,24 +15,24 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
 # export CUDA_VISIBLE_DEVICES=4,5,6
-set_n_least_used_CUDA_VISIBLE_DEVICES 4
+set_n_least_used_CUDA_VISIBLE_DEVICES 2
 PROJECT_NAME="sft"
 PARENT_SAVE_DIR="/home/yeanbang/data/experiment/output/model" # Path to a folder to save checkpoints
 PARENT_TENSORBOARD_DIR="/home/yeanbang/data/experiment/logs/tensorboard" # Path to a folder to save logs
 PARENT_CONFIG_FILE="/home/yeanbang/data/experiment/logs/config" # Path to a folder to save training config logs
-PRETRAINED_MODEL_PATH="/mnt/jfs-hdd/share/models/Llama-2-7b-chat-hf" # huggingface or local model path
-PRETRAINED_TOKENIZER_PATH="/mnt/jfs-hdd/share/models/Llama-2-7b-chat-hf" # huggingface or local tokenizer path
+PRETRAINED_MODEL_PATH="/mnt/jfs-hdd/share/models/Yi-1.5-6B" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="/mnt/jfs-hdd/share/models/Yi-1.5-6B" # huggingface or local tokenizer path
 declare -a dataset=(
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00000
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00001
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00002
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00003
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00004
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00005
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00006
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00007
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00008
-    /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00009
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00000
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00001
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00002
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00003
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00004
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00005
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00006
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00007
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00008
+    /home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00009
 )
 TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
@@ -43,7 +43,7 @@ CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
 echo $(which colossalai)
 echo $(which python)
 # the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size
-colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile train_sft.py \
+colossalai run --nproc_per_node 2 --master_port 31312 --hostfile ./hostfile train_sft.py \
     --pretrain $PRETRAINED_MODEL_PATH \
     --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
     --save_interval 4000 \
@@ -51,13 +51,13 @@ colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile trai
     --save_path $SAVE_DIR \
     --config_file $CONFIG_FILE \
     --lora_rank 0 \
-    --plugin zero2 \
-    --tp 1 \
+    --plugin 3d \
+    --tp 2 \
     --pp 1 \
-    --zero_stage 2 \
-    --batch_size 4 \
+    --zero_stage 0 \
+    --batch_size 2 \
     --max_epochs 3 \
-    --accumulation_steps 4 \
+    --accumulation_steps 1 \
     --lr 5e-5 \
     --max_len 400 \
     --grad_checkpoint \

View File

@@ -4,5 +4,6 @@
     "stop_ids": [
         29871,
         2
-    ]
-}
+    ],
+    "end_of_assistant": "</s>"
+}
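Note (editor's addition): end_of_assistant ties back to the apply_rlhf_data_format hunk at the top of this commit, where it is passed to split_templated_prompt_into_chunks so the tokenization code can tell where each assistant reply ends when building the loss mask. A much-simplified illustration (made-up prompt, not the actual coati logic):

# Locate the assistant reply span using the end_of_assistant marker ("</s>" here).
rendered = "User: What is 2 + 2? Assistant: 2 + 2 equals 4.</s>"  # made-up rendering
end_of_assistant = "</s>"

reply_start = rendered.index("Assistant:") + len("Assistant:")
reply_end = rendered.index(end_of_assistant, reply_start) + len(end_of_assistant)

# Only the span rendered[reply_start:reply_end] would keep loss labels;
# everything before the assistant turn would be masked.
print(rendered[reply_start:reply_end])  # " 2 + 2 equals 4.</s>"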

View File

@@ -9,14 +9,13 @@ model_data_mapping = {
     'THUDM/chatglm2-6b': 'THUDM_chatglm2-6b.json',
     'THUDM/chatglm3-6b': 'THUDM_chatglm3-6b.json',
     'baichuan-inc/Baichuan2-13B-Chat': 'baichuan-inc_Baichuan2-13B-Chat.json',
-    'Qwen/Qwen-7B-Chat': 'Qwen_Qwen-7B-Chat.json',
     '01-ai/Yi-1.5-9B-Chat': '01-ai_Yi-1.5-9B-Chat.json',
     '01-ai/Yi-34B': '01-ai_Yi-34B.json',
     'deepseek-ai/DeepSeek-V2-Lite': 'deepseek-ai_DeepSeek-V2-Lite.json',
     'microsoft/phi-2': 'microsoft_phi-2.json',
     'mistralai/Mixtral-8x7B-Instruct-v0.1': 'mistralai_Mixtral-8x7B-Instruct-v0.1.json'
 }
-chat_template_config_path = '../config/conversation_template'
+chat_template_config_path = './config/conversation_template'
 def test_tokenization_sft():
@@ -34,5 +33,5 @@ def test_tokenization_sft():
         )
         output = supervised_tokenize_sft({"messages": messages}, tokenizer, conversation_template)
-        with open(f"./test_data/chat_template/{model_data_mapping[model]}", "r", encoding="utf8") as f:
+        with open(f"./tests/test_data/chat_template/{model_data_mapping[model]}", "r", encoding="utf8") as f:
             assert json.dumps(json.load(f)) == json.dumps(output), f"model: {model} failed"

View File

@@ -6,35 +6,59 @@ TEST_DATA_DIR=$BASE_DIR/tests/test_data
 DATA_SAVE_PATH=$BASE_TEMP_DIR/tests
 CONFIG_DIR=$BASE_DIR/config
-MODELS=("colossal-llama2" "llama2" "zephyr" "mistral" "chatGLM2" "Qwen" "Vicuna" "Yi")
+MODELS=("colossal-llama2" "llama2" "mistral" "chatGLM2" "chatGLM3" "deepseek" "Yi" "baichuan")
+#
 get_pretrain() {
     local model=$1
     if [[ $model == "colossal-llama2" ]]; then
         echo "hpcai-tech/Colossal-LLaMA-2-7b-base"
     elif [[ $model == "llama2" ]]; then
         echo "hf-internal-testing/llama-tokenizer"
-    elif [[ $model == "zephyr" ]]; then
-        echo "HuggingFaceH4/zephyr-7b-beta"
+    elif [[ $model == "phi" ]]; then
+        echo "microsoft/phi-2"
     elif [[ $model == "mistral" ]]; then
-        echo "mistralai/Mistral-7B-Instruct-v0.2"
+        echo "mistralai/Mistral-7B-Instruct-v0.3"
     elif [[ $model == "chatGLM2" ]]; then
         echo "THUDM/chatglm2-6b"
-    elif [[ $model == "Qwen" ]]; then
-        echo "Qwen/Qwen-7B-Chat"
-    elif [[ $model == "Vicuna" ]]; then
-        echo "lmsys/vicuna-7b-v1.5"
+    elif [[ $model == "chatGLM3" ]]; then
+        echo "THUDM/chatglm3-6b"
+    elif [[ $model == "deepseek" ]]; then
+        echo "deepseek-ai/DeepSeek-V2-Lite"
     elif [[ $model == "Yi" ]]; then
-        echo "01-ai/Yi-6B-Chat"
+        echo "01-ai/Yi-1.5-9B-Chat"
+    elif [[ $model == "baichuan" ]]; then
+        echo "baichuan-inc/Baichuan2-13B-Chat"
     else
         echo "Unknown model $model"
         exit 1
     fi
 }
 get_conversation_template_config() {
     local model=$1
-    echo "$CONFIG_DIR/conversation_template/$model.json"
+    if [[ $model == "colossal-llama2" ]]; then
+        echo "$CONFIG_DIR/conversation_template/colossal-llama2.json"
+    elif [[ $model == "llama2" ]]; then
+        echo "$CONFIG_DIR/conversation_template/llama2.json"
+    elif [[ $model == "deepseek" ]]; then
+        echo "$CONFIG_DIR/conversation_template/deepseek-ai_DeepSeek-V2-Lite.json"
+    elif [[ $model == "mistral" ]]; then
+        echo "$CONFIG_DIR/conversation_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
+    elif [[ $model == "chatGLM2" ]]; then
+        echo "$CONFIG_DIR/conversation_template/THUDM_chatglm2-6b.json"
+    elif [[ $model == "chatGLM3" ]]; then
+        echo "$CONFIG_DIR/conversation_template/THUDM_chatglm3-6b.json"
+    elif [[ $model == "phi" ]]; then
+        echo "$CONFIG_DIR/conversation_template/microsoft_phi-2.json"
+    elif [[ $model == "Yi" ]]; then
+        echo "$CONFIG_DIR/conversation_template/01-ai_Yi-1.5-9B-Chat.json"
+    elif [[ $model == "baichuan" ]]; then
+        echo "$CONFIG_DIR/conversation_template/baichuan-inc_Baichuan2-13B-Chat.json"
+    else
+        echo "Unknown model $model"
+        exit 1
+    fi
 }
 # Test SFT data Preparation

View File

@@ -30,7 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
 MODELS_DIR=$TEMP_DIR/models_config
 # Skip those tests due to CI tests timeout
 MODELS=('llama')
-PLUGINS=('gemini' 'gemini_auto' 'zero2' 'zero2_cpu' '3d')
+PLUGINS=('3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu')
 LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
 export OMP_NUM_THREADS=8