2023-03-28 12:25:36 +00:00
|
|
|
#!/usr/bin/env bash
|
|
|
|
|
2023-06-25 09:36:21 +00:00
|
|
|
#######################################
# Restrict CUDA_VISIBLE_DEVICES to the n GPUs with the least memory in use.
# Arguments: $1 - number of GPUs to keep (default 9999, i.e. effectively all)
# Globals:   CUDA_VISIBLE_DEVICES (written and exported)
# Outputs:   the per-GPU memory table and the resulting setting on stdout
#######################################
set_n_least_used_CUDA_VISIBLE_DEVICES() {
  local n=${1:-"9999"}
  local usage_table gpu_ids

  echo "GPU Memory Usage:"
  # One row per GPU: "<index>\t<used> MiB". tail drops the CSV header and
  # nl numbers rows from 0.
  # NOTE(review): this assumes nvidia-smi's row order matches CUDA's device
  # numbering (guaranteed only with CUDA_DEVICE_ORDER=PCI_BUS_ID) - confirm.
  usage_table=$(nvidia-smi --query-gpu=memory.used --format=csv |
    tail -n +2 |
    nl -v 0)
  # Print to stdout instead of /dev/tty: /dev/tty cannot be opened when the
  # script runs without a controlling terminal (e.g. in CI).
  if [[ -n "$usage_table" ]]; then
    printf '%s\n' "$usage_table"
  fi

  # Sort by used memory (column 2), keep the first n indices, comma-join them.
  gpu_ids=$(printf '%s\n' "$usage_table" |
    sort -g -k 2 |
    awk '{print $1}' |
    head -n "$n" |
    paste -sd, -)
  export CUDA_VISIBLE_DEVICES=$gpu_ids

  echo "Now CUDA_VISIBLE_DEVICES is set to:"
  echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

# Expose only the 4 GPUs with the least memory in use to everything below.
set_n_least_used_CUDA_VISIBLE_DEVICES 4

set -xue

#######################################
# Exit with status 1 unless the named environment variable is set and
# non-empty.
# Arguments: $1 - variable name; $2 - hint printed to stderr on failure
#######################################
require_env() {
  local var_name=$1
  local hint=$2
  # ${!var_name:-} expands an unset variable to "" instead of letting
  # `set -u` abort with a bare "unbound variable" error before the hint.
  if [[ -z "${!var_name:-}" ]]; then
    echo "$hint" >&2
    exit 1
  fi
}

require_env SFT_DATASET "Please set \$SFT_DATASET to the path to sft dataset."
require_env PROMPT_PATH "Please set \$PROMPT_PATH to the path to prompts csv."
require_env PRETRAIN_DATASET "Please set \$PRETRAIN_DATASET to the path to alpaca data."

# Absolute directory of this script; every other path is resolved against it.
# Quoted so a checkout path containing spaces still works.
BASE=$(realpath "$(dirname "$0")")

export OMP_NUM_THREADS=8

# install requirements
pip install -r "${BASE}/requirements.txt"

# Initialise wandb in offline mode.
wandb init -m offline

# FIXME: This is a hack to skip tests that are not working
# Each entry is a "<model>-<strategy>" pair that the test loop skips.

# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
SKIPPED_TESTS=("gpt2-ddp")

# - llama-*: These tests can be passed locally, skipped for long execution time
SKIPPED_TESTS+=("llama-ddp" "llama-colossalai_gemini" "llama-colossalai_zero2")

# These tests are quick and do not have any dependencies
# Runs train_prompts.py once for every model/strategy pair that is not listed
# in SKIPPED_TESTS.
for model in 'gpt2' 'bloom' 'opt' 'llama'; do
  for strategy in 'ddp' 'colossalai_gemini' 'colossalai_zero2'; do
    # Space-padding on both sides makes this a whole-entry match only.
    if [[ " ${SKIPPED_TESTS[*]} " == *" ${model}-${strategy} "* ]]; then
      echo "[Test]: Skipped $model-$strategy"
      continue
    fi
    torchrun --standalone --nproc_per_node=2 "${BASE}/train_prompts.py" \
      --prompt_dataset "$PROMPT_PATH" --pretrain_dataset "$PRETRAIN_DATASET" \
      --strategy "$strategy" --model "$model" \
      --num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \
      --train_batch_size 2 --lora_rank 4
  done
done

# train sft
# Every run saves to ${BASE}/output, which is deleted straight afterwards so
# the next run starts clean. ${BASE:?} makes rm abort if BASE were ever empty.
torchrun --standalone --nproc_per_node=4 "${BASE}/train_sft.py" --pretrain 'bigscience/bloom-560m' \
  --model 'bloom' --strategy colossalai_zero2 --lora_rank 4 \
  --dataset "$SFT_DATASET" --max_datasets_size 512 --max_epochs 1 \
  --save_path "${BASE}/output"
rm -rf -- "${BASE:?}/output"

torchrun --standalone --nproc_per_node=4 "${BASE}/train_sft.py" --pretrain 'gpt2' \
  --model 'gpt2' --strategy colossalai_zero2 \
  --dataset "$SFT_DATASET" --max_datasets_size 512 --max_epochs 1 \
  --save_path "${BASE}/output"
rm -rf -- "${BASE:?}/output"

torchrun --standalone --nproc_per_node=4 "${BASE}/train_sft.py" --pretrain 'facebook/opt-350m' \
  --model 'opt' --strategy colossalai_zero2 --lora_rank 4 \
  --dataset "$SFT_DATASET" --max_datasets_size 512 --max_epochs 1 \
  --save_path "${BASE}/output"
rm -rf -- "${BASE:?}/output"

torchrun --standalone --nproc_per_node=4 "${BASE}/train_sft.py" --pretrain 'gpt2' \
  --model 'gpt2' --strategy ddp --lora_rank 4 \
  --dataset "$SFT_DATASET" --max_datasets_size 512 --max_epochs 1 \
  --save_path "${BASE}/output"
rm -rf -- "${BASE:?}/output"

# train rm
# rm_ckpt_opt.pt and rm_ckpt_gpt.pt are kept because the "train rl" stage
# loads them; the LoRA runs share rm_ckpt.pt, which is deleted after each run.
torchrun --standalone --nproc_per_node=2 "${BASE}/train_reward_model.py" \
  --pretrain 'facebook/opt-350m' --model 'opt' \
  --strategy colossalai_zero2 --loss_fn 'log_sig' \
  --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
  --test True --lora_rank 0 \
  --save_path "${BASE}/rm_ckpt_opt.pt"

torchrun --standalone --nproc_per_node=2 "${BASE}/train_reward_model.py" \
  --pretrain 'gpt2' --model 'gpt2' \
  --strategy colossalai_zero2 --loss_fn 'log_exp' \
  --dataset 'Dahoas/rm-static' \
  --test True --lora_rank 0 \
  --save_path "${BASE}/rm_ckpt_gpt.pt"

torchrun --standalone --nproc_per_node=2 "${BASE}/train_reward_model.py" \
  --pretrain 'gpt2' --model 'gpt2' \
  --strategy ddp --loss_fn 'log_exp' \
  --dataset 'Dahoas/rm-static' \
  --test True --lora_rank 4 \
  --save_path "${BASE}/rm_ckpt.pt"
rm -rf -- "${BASE:?}/rm_ckpt.pt"

torchrun --standalone --nproc_per_node=2 "${BASE}/train_reward_model.py" \
  --pretrain 'bigscience/bloom-560m' --model 'bloom' \
  --strategy colossalai_zero2 --loss_fn 'log_sig' \
  --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
  --test True --lora_rank 4 \
  --save_path "${BASE}/rm_ckpt.pt"
rm -rf -- "${BASE:?}/rm_ckpt.pt"

# train rl
# Loads the reward-model checkpoints saved by the "train rm" stage and deletes
# each one after its last use. All runs share one actor save path, which is
# removed at the end. ${BASE:?} makes rm abort if BASE were ever empty.
torchrun --standalone --nproc_per_node=2 "${BASE}/train_prompts.py" \
  --prompt_dataset "$PROMPT_PATH" --pretrain_dataset "$PRETRAIN_DATASET" \
  --strategy colossalai_zero2 --num_episodes 1 \
  --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
  --pretrain 'facebook/opt-350m' --model opt \
  --rm_pretrain 'facebook/opt-350m' \
  --rm_path "${BASE}/rm_ckpt_opt.pt" \
  --save_path "${BASE}/actor_checkpoint_prompts.pt"
rm -rf -- "${BASE:?}/rm_ckpt_opt.pt"

torchrun --standalone --nproc_per_node=2 "${BASE}/train_prompts.py" \
  --prompt_dataset "$PROMPT_PATH" --pretrain_dataset "$PRETRAIN_DATASET" \
  --strategy colossalai_zero2 --num_episodes 1 \
  --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
  --pretrain 'gpt2' --model gpt2 \
  --rm_pretrain 'gpt2' \
  --rm_path "${BASE}/rm_ckpt_gpt.pt" \
  --save_path "${BASE}/actor_checkpoint_prompts.pt"

torchrun --standalone --nproc_per_node=2 "${BASE}/train_prompts.py" \
  --prompt_dataset "$PROMPT_PATH" --pretrain_dataset "$PRETRAIN_DATASET" \
  --strategy colossalai_gemini --num_episodes 1 \
  --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
  --pretrain 'gpt2' --model gpt2 \
  --rm_pretrain 'gpt2' \
  --rm_path "${BASE}/rm_ckpt_gpt.pt" \
  --save_path "${BASE}/actor_checkpoint_prompts.pt"
rm -rf -- "${BASE:?}/rm_ckpt_gpt.pt"

rm -rf -- "${BASE:?}/actor_checkpoint_prompts.pt"

# RTX 3080 GPUs don't support P2P communication, so this Ray test is skipped:
|
|
|
|
# cd ${BASE}/ray && bash test_ci.sh && cd ${BASE}
|