From 8cc8f645cd1d971a3bef52f625b7881f17c6d22b Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 19 Jul 2024 10:10:08 +0800 Subject: [PATCH 1/2] [Examples] Add lazy init to OPT and GPT examples (#5924) Co-authored-by: Edenzzzz --- .../gpt/hybridparallelism/finetune.py | 17 ++++++++---- examples/language/opt/opt_benchmark.py | 24 +++++++++++------ examples/language/opt/opt_train_demo.py | 27 +++++++++++++------ 3 files changed, 47 insertions(+), 21 deletions(-) diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py index 777d16cb9..ae6d655f4 100644 --- a/examples/language/gpt/hybridparallelism/finetune.py +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -1,4 +1,5 @@ import argparse +from contextlib import nullcontext from typing import Callable, List, Union import evaluate @@ -17,6 +18,7 @@ from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam # ============================== @@ -186,7 +188,6 @@ def main(): help="only gpt2 now", ) parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached") - parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context") args = parser.parse_args() if args.model_type == "gpt2": @@ -250,10 +251,16 @@ def main(): pad_token_id=data_builder.tokenizer.pad_token_id, ) - if model_name == "gpt2": - model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda() - else: - raise RuntimeError + init_ctx = ( + LazyInitContext(default_device=get_accelerator().get_current_device()) + if isinstance(plugin, (GeminiPlugin)) + else nullcontext() + ) + with init_ctx: + if model_name == "gpt2": + model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda() + else: + raise RuntimeError # optimizer no_decay = ["bias", "LayerNorm.weight"] diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py index c2883d96c..ca9b63d1a 100755 --- a/examples/language/opt/opt_benchmark.py +++ b/examples/language/opt/opt_benchmark.py @@ -1,4 +1,5 @@ import time +from contextlib import nullcontext import torch import tqdm @@ -8,9 +9,11 @@ from transformers import AutoConfig, OPTForCausalLM from transformers.utils.versions import require_version import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam @@ -62,14 +65,6 @@ def main(): if args.mem_cap > 0: colo_memory_cap(args.mem_cap) - # Build OPT model - config = AutoConfig.from_pretrained(args.model_name_or_path) - model = OPTForCausalLM(config=config) - logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) - - # Enable gradient checkpointing - model.gradient_checkpointing_enable() - # Set plugin booster_kwargs = {} if args.plugin == "torch_ddp_fp16": @@ -82,6 +77,19 @@ def main(): plugin = LowLevelZeroPlugin(initial_scale=2**5) logger.info(f"Set plugin as {args.plugin}", ranks=[0]) + # Build OPT model + init_ctx = ( + LazyInitContext(default_device=get_accelerator().get_current_device()) + if isinstance(plugin, (GeminiPlugin)) + else nullcontext() + ) + config = AutoConfig.from_pretrained(args.model_name_or_path) + with init_ctx: + model = OPTForCausalLM(config=config) + logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() # Set optimizer optimizer = HybridAdam(model.parameters(), lr=args.learning_rate) diff --git a/examples/language/opt/opt_train_demo.py b/examples/language/opt/opt_train_demo.py index b5b50305c..50dfc7bff 100644 --- a/examples/language/opt/opt_train_demo.py +++ b/examples/language/opt/opt_train_demo.py @@ -1,3 +1,5 @@ +from contextlib import nullcontext + import datasets import torch import transformers @@ -8,9 +10,11 @@ from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_s from transformers.utils.versions import require_version import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam @@ -78,14 +82,6 @@ def main(): datasets.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error() - # Build OPT model - config = AutoConfig.from_pretrained(args.model_name_or_path) - model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config) - logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) - - # Enable gradient checkpointing - model.gradient_checkpointing_enable() - # Set plugin booster_kwargs = {} if args.plugin == "torch_ddp_fp16": @@ -110,6 +106,21 @@ def main(): logger.info(f"Set plugin as {args.plugin}", ranks=[0]) + # Build OPT model + config = AutoConfig.from_pretrained(args.model_name_or_path) + # Build OPT model + init_ctx = ( + LazyInitContext(default_device=get_accelerator().get_current_device()) + if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) + else nullcontext() + ) + with init_ctx: + model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config) + logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + # Prepare tokenizer and dataloader tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) dataset = NetflixDataset(tokenizer) From f585d4e38e4aa6145ec94fe4d6f0a1fe94bc4192 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Fri, 19 Jul 2024 13:40:07 +0800 Subject: [PATCH 2/2] [ColossalChat] Hotfix for ColossalChat (#5910) * add ignore and tiny llama * fix path issue * run style * fix issue * update bash * add ignore and tiny llama * fix path issue * run style * fix issue * update bash * fix ddp issue * add Qwen 1.5 32B --- applications/ColossalChat/.gitignore | 3 +++ applications/ColossalChat/coati/trainer/sft.py | 2 +- .../Qwen_Qwen1.5-32B-Chat.json | 9 +++++++++ .../conversation_template/tiny-llama.json | 8 ++++++++ applications/ColossalChat/examples/README.md | 18 +++++++++--------- .../examples/training_scripts/train_dpo.sh | 6 ++---- .../examples/training_scripts/train_orpo.sh | 8 +++----- .../examples/training_scripts/train_ppo.sh | 5 ++--- .../examples/training_scripts/train_rm.sh | 6 ++---- .../examples/training_scripts/train_sft.py | 2 +- .../examples/training_scripts/train_sft.sh | 5 ++--- 11 files changed, 42 insertions(+), 30 deletions(-) create mode 100644 applications/ColossalChat/config/conversation_template/Qwen_Qwen1.5-32B-Chat.json create mode 100644 applications/ColossalChat/config/conversation_template/tiny-llama.json diff --git a/applications/ColossalChat/.gitignore b/applications/ColossalChat/.gitignore index 33950adc0..757cbb5da 100755 --- a/applications/ColossalChat/.gitignore +++ b/applications/ColossalChat/.gitignore @@ -146,6 +146,9 @@ docs/.build examples/wandb/ examples/logs/ examples/output/ +examples/training_scripts/logs +examples/training_scripts/wandb +examples/training_scripts/output examples/awesome-chatgpt-prompts/ temp/ diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py index 1484f5057..c09d61034 100755 --- a/applications/ColossalChat/coati/trainer/sft.py +++ b/applications/ColossalChat/coati/trainer/sft.py @@ -102,7 +102,6 @@ class SFTTrainer(SLTrainer): batch_size = batch["input_ids"].size(0) outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) loss = outputs.loss - step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}") self.booster.backward(loss=loss, optimizer=self.optimizer) @@ -115,6 +114,7 @@ class SFTTrainer(SLTrainer): self.optimizer.zero_grad() self.scheduler.step() + step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")}) if self.writer: self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step) diff --git a/applications/ColossalChat/config/conversation_template/Qwen_Qwen1.5-32B-Chat.json b/applications/ColossalChat/config/conversation_template/Qwen_Qwen1.5-32B-Chat.json new file mode 100644 index 000000000..58941a591 --- /dev/null +++ b/applications/ColossalChat/config/conversation_template/Qwen_Qwen1.5-32B-Chat.json @@ -0,0 +1,9 @@ +{ + "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "stop_ids": [ + 151645, + 151643 + ], + "end_of_assistant": "<|im_end|>" +} diff --git a/applications/ColossalChat/config/conversation_template/tiny-llama.json b/applications/ColossalChat/config/conversation_template/tiny-llama.json new file mode 100644 index 000000000..59196159f --- /dev/null +++ b/applications/ColossalChat/config/conversation_template/tiny-llama.json @@ -0,0 +1,8 @@ +{ + "chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "stop_ids": [ + 2 + ], + "end_of_assistant": "" +} diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md index bdf4d23f1..d6114c8d5 100755 --- a/applications/ColossalChat/examples/README.md +++ b/applications/ColossalChat/examples/README.md @@ -490,7 +490,7 @@ In this code we provide a flexible way for users to set the conversation templat On your first run of the data preparation script, you only need to define the "chat_template" (if you want to use custom chat template) and the "system message" (if you want to use a custom system message), -- Step 2: Run the data preparation script--- [prepare_sft_dataset.sh](./examples/data_preparation_scripts/prepare_sft_dataset.sh). Note that whether or not you have skipped the first step, you need to provide the path to the conversation template config file (via the conversation_template_config arg). If you skipped the first step, an auto-generated conversation template will be stored at the designated file path. +- Step 2: Run the data preparation script--- [prepare_sft_dataset.sh](./data_preparation_scripts/prepare_sft_dataset.sh). Note that whether or not you have skipped the first step, you need to provide the path to the conversation template config file (via the conversation_template_config arg). If you skipped the first step, an auto-generated conversation template will be stored at the designated file path. - Step 3: (Optional) Check the correctness of the processed data. We provided an easy way for you to do a manual checking on the processed data by checking the "$SAVE_DIR/jsonl/part-XXXX.jsonl" files. @@ -510,7 +510,7 @@ Human: what are some pranks with a pen i can do? Assistant: Are you #### Step 3: Training -Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./examples/training_scripts/train_sft.sh) to start a supervised instructs fine-tuning. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. +Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./training_scripts/train_sft.sh) to start a supervised instructs fine-tuning. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. ### RLHF Training Stage2 - Training Reward Model @@ -552,11 +552,11 @@ Below shows the preference dataset format used in training the reward model. #### Step 2: Preprocessing -Similar to the second step in the previous stage, we format the reward data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./examples/data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training. +Similar to the second step in the previous stage, we format the reward data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training. #### Step 3: Training -You can run [train_rm.sh](./examples/training_scripts/train_rm.sh) to start the reward model training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. +You can run [train_rm.sh](./training_scripts/train_rm.sh) to start the reward model training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. #### Features and Tricks in RM Training @@ -627,14 +627,14 @@ The second dataset--- pretrained dataset is optional, provide it if you want to ] ``` #### Step 2: Preprocessing -To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./examples/data_preparation_scripts/prepare_prompt_dataset.sh) +To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./data_preparation_scripts/prepare_prompt_dataset.sh) You can use the SFT dataset you prepared in the SFT stage or prepare a new one from different source for the ptx dataset. The ptx data is used to calculate ptx loss, which stabilizes the training according to the [InstructGPT paper](https://arxiv.org/pdf/2203.02155.pdf). #### Step 3: Training -You can run the [train_ppo.sh](./examples/training_scripts/train_ppo.sh) to start PPO training. Here are some unique arguments for PPO, please refer to the training configuration section for other training configuration. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. +You can run the [train_ppo.sh](./training_scripts/train_ppo.sh) to start PPO training. Here are some unique arguments for PPO, please refer to the training configuration section for other training configuration. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. ```bash @@ -718,7 +718,7 @@ For DPO training, you only need the preference dataset. Please follow the instru #### Step 2: Training -You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. Following the trend of recent research on DPO-like alignment methods, we added option for the user to choose from, including whether to do length normalization , reward shaping and whether to use a reference model in calculating implicit reward. Here are those options, +You can run the [train_dpo.sh](./training_scripts/train_dpo.sh) to start DPO training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. Following the trend of recent research on DPO-like alignment methods, we added option for the user to choose from, including whether to do length normalization , reward shaping and whether to use a reference model in calculating implicit reward. Here are those options, ``` --beta 0.1 \ # the temperature in DPO loss, Default to 0.1 @@ -735,7 +735,7 @@ You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to star ### Alternative Option For RLHF: Simple Preference Optimization We support the method introduced in the paper [SimPO: Simple Preference Optimization -with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO). Which is a reference model free aligment method that add length normalization and reward shaping to the DPO loss to enhance training stability and efficiency. As the method doesn't deviate too much from DPO, we add support for length normalization and SimPO reward shaping in our DPO implementation. To use SimPO in alignment, use the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) script, set the `loss_type` to `simpo_loss`, you can also set the value for temperature (`beta`) and reward target margin (`gamma`) but it is optional. +with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO). Which is a reference model free aligment method that add length normalization and reward shaping to the DPO loss to enhance training stability and efficiency. As the method doesn't deviate too much from DPO, we add support for length normalization and SimPO reward shaping in our DPO implementation. To use SimPO in alignment, use the [train_dpo.sh](./training_scripts/train_dpo.sh) script, set the `loss_type` to `simpo_loss`, you can also set the value for temperature (`beta`) and reward target margin (`gamma`) but it is optional. #### SimPO Result

@@ -744,7 +744,7 @@ with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO). Which i ### Alternative Option For RLHF: Odds Ratio Preference Optimization -We support the method introduced in the paper [ORPO: Monolithic Preference Optimization without Reference Model](https://arxiv.org/abs/2403.07691) (ORPO). Which is a reference model free aligment method that mixes the SFT loss with a reinforcement learning loss that uses odds ratio as the implicit reward to enhance training stability and efficiency. Simply set the flag to disable the use of the reference model, set the reward target margin and enable length normalization in the DPO training script. To use ORPO in alignment, use the [train_orpo.sh](./examples/training_scripts/train_orpo.sh) script, You can set the value for `lambda` (which determine how strongly the reinforcement learning loss affect the training) but it is optional. +We support the method introduced in the paper [ORPO: Monolithic Preference Optimization without Reference Model](https://arxiv.org/abs/2403.07691) (ORPO). Which is a reference model free aligment method that mixes the SFT loss with a reinforcement learning loss that uses odds ratio as the implicit reward to enhance training stability and efficiency. Simply set the flag to disable the use of the reference model, set the reward target margin and enable length normalization in the DPO training script. To use ORPO in alignment, use the [train_orpo.sh](./training_scripts/train_orpo.sh) script, You can set the value for `lambda` (which determine how strongly the reinforcement learning loss affect the training) but it is optional. #### ORPO Result

diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.sh b/applications/ColossalChat/examples/training_scripts/train_dpo.sh index f7bb45658..082d54ff0 100755 --- a/applications/ColossalChat/examples/training_scripts/train_dpo.sh +++ b/applications/ColossalChat/examples/training_scripts/train_dpo.sh @@ -15,9 +15,8 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { } set_n_least_used_CUDA_VISIBLE_DEVICES 4 -PROJECT_NAME="dpo" +PROJECT_NAME="DPO" PARENT_SAVE_DIR="" # Path to a folder to save checkpoints -PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs PARENT_CONFIG_FILE="" # Path to a folder to save training config logs PRETRAINED_MODEL_PATH="" # huggingface or local model path PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path @@ -38,11 +37,10 @@ declare -a dataset=( TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" -CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json" +CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json" colossalai run --nproc_per_node 4 --hostfile hostfile --master_port 31313 train_dpo.py \ --pretrain $PRETRAINED_MODEL_PATH \ - --checkpoint_path $PRETRAINED_MODEL_PATH \ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ --dataset ${dataset[@]} \ --plugin "zero2" \ diff --git a/applications/ColossalChat/examples/training_scripts/train_orpo.sh b/applications/ColossalChat/examples/training_scripts/train_orpo.sh index ca80a14c1..482956b21 100755 --- a/applications/ColossalChat/examples/training_scripts/train_orpo.sh +++ b/applications/ColossalChat/examples/training_scripts/train_orpo.sh @@ -13,11 +13,10 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { echo "Now CUDA_VISIBLE_DEVICES is set to:" echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" } -set_n_least_used_CUDA_VISIBLE_DEVICES 8 +set_n_least_used_CUDA_VISIBLE_DEVICES 2 -PROJECT_NAME="dpo" +PROJECT_NAME="ORPO" PARENT_SAVE_DIR="" # Path to a folder to save checkpoints -PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs PARENT_CONFIG_FILE="" # Path to a folder to save training config logs PRETRAINED_MODEL_PATH="" # huggingface or local model path PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path @@ -38,11 +37,10 @@ declare -a dataset=( TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" -CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json" +CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json" colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31313 train_orpo.py \ --pretrain $PRETRAINED_MODEL_PATH \ - --checkpoint_path $PRETRAINED_MODEL_PATH \ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ --dataset ${dataset[@]} \ --plugin "zero2" \ diff --git a/applications/ColossalChat/examples/training_scripts/train_ppo.sh b/applications/ColossalChat/examples/training_scripts/train_ppo.sh index 91633978e..277e75e6d 100755 --- a/applications/ColossalChat/examples/training_scripts/train_ppo.sh +++ b/applications/ColossalChat/examples/training_scripts/train_ppo.sh @@ -15,10 +15,9 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { } set_n_least_used_CUDA_VISIBLE_DEVICES 8 -PROJECT_NAME="ppo" +PROJECT_NAME="PPO" PARENT_SAVE_DIR="" # Path to a folder to save checkpoints -PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs PARENT_CONFIG_FILE="" # Path to a folder to save training config logs PRETRAINED_MODEL_PATH="" # local pretrained model path (from RLHF step 1: SFT) PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path @@ -54,7 +53,7 @@ declare -a ptx_dataset=( TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" -CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json" +CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json" colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_ppo.py \ --pretrain $PRETRAINED_MODEL_PATH \ diff --git a/applications/ColossalChat/examples/training_scripts/train_rm.sh b/applications/ColossalChat/examples/training_scripts/train_rm.sh index e06d9092f..cd42afcc8 100755 --- a/applications/ColossalChat/examples/training_scripts/train_rm.sh +++ b/applications/ColossalChat/examples/training_scripts/train_rm.sh @@ -15,9 +15,8 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { } set_n_least_used_CUDA_VISIBLE_DEVICES 8 -PROJECT_NAME="rm" +PROJECT_NAME="RM" PARENT_SAVE_DIR="" # Path to a folder to save checkpoints -PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs PARENT_CONFIG_FILE="" # Path to a folder to save training config logs PRETRAINED_MODEL_PATH="" # huggingface or local model path PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path @@ -38,11 +37,10 @@ declare -a dataset=( TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" -CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json" +CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json" colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_rm.py \ --pretrain $PRETRAINED_MODEL_PATH \ - --checkpoint_path /home/yeanbang/data/experiments/rm/hhh_aligh/ckptllama2-rm-2024-01-17-14-43-24/epoch-1_step-1317/modeling \ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ --dataset ${dataset[@]} \ --plugin "zero2" \ diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py index fe1506559..b89cbeb91 100755 --- a/applications/ColossalChat/examples/training_scripts/train_sft.py +++ b/applications/ColossalChat/examples/training_scripts/train_sft.py @@ -61,7 +61,7 @@ def train(args): Default torch ddp plugin without any acceleration, for debugging purpose acceleration, for debugging purpose """ - plugin = TorchDDPPlugin(find_unused_parameters=True) + plugin = TorchDDPPlugin(find_unused_parameters=True if args.grad_checkpoint is False else False) elif args.plugin == "gemini": plugin = GeminiPlugin( precision=args.mixed_precision, diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.sh b/applications/ColossalChat/examples/training_scripts/train_sft.sh index 18df09293..c7d38c1d8 100755 --- a/applications/ColossalChat/examples/training_scripts/train_sft.sh +++ b/applications/ColossalChat/examples/training_scripts/train_sft.sh @@ -14,9 +14,8 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { } set_n_least_used_CUDA_VISIBLE_DEVICES 4 -PROJECT_NAME="sft" +PROJECT_NAME="SFT" PARENT_SAVE_DIR="" # Path to a folder to save checkpoints -PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs PARENT_CONFIG_FILE="" # Path to a folder to save training config logs PRETRAINED_MODEL_PATH="" # huggingface or local model path PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path @@ -36,7 +35,7 @@ declare -a dataset=( TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" -CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json" +CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json" echo $(which colossalai) echo $(which python)