[chatgpt]Reward Model Training Process update (#3133)

* add normalize function to value_head in bloom rm

* add normalization to value_function in gpt_rm

* add normalization to value_head of opt_rm

* add Anthropic/hh-rlhf dataset

* Update __init__.py

* Add LogExpLoss in RM training

* Update __init__.py

* update rm trainer to use acc as target

* update example/train_rm

* Update train_rm.sh

* code style

* Update README.md

* Update README.md

* add rm test to ci

* fix tokenier

* fix typo

* change batchsize to avoid oom in ci

* Update test_ci.sh
pull/3159/head
BlueRum 2023-03-20 09:59:06 +08:00 committed by GitHub
parent 1e58d31bb7
commit 7548ca5a54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 270 additions and 111 deletions

View File

@ -1,4 +1,4 @@
from .reward_dataset import RewardDataset
from .reward_dataset import RmStaticDataset, HhRlhfDataset
from .utils import is_rank_0
__all__ = ['RewardDataset', 'is_rank_0']
__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0']

View File

@ -5,8 +5,8 @@ from tqdm import tqdm
from .utils import is_rank_0
class RewardDataset(Dataset):
# Dahaos/rm-static
class RmStaticDataset(Dataset):
"""
Dataset for reward model
@ -14,16 +14,21 @@ class RewardDataset(Dataset):
dataset: dataset for reward model
tokenizer: tokenizer for reward model
max_length: max length of input
special_token: special token at the end of sentence
"""
def __init__(self, dataset, tokenizer: Callable, max_length: int) -> None:
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
self.chosen = []
self.reject = []
if special_token is None:
self.end_token = tokenizer.eos_token
else:
self.end_token = special_token
for data in tqdm(dataset, disable=not is_rank_0()):
prompt = data['prompt']
chosen = prompt + data['chosen'] + "<|endoftext|>"
chosen = prompt + data['chosen'] + self.end_token
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
@ -34,7 +39,57 @@ class RewardDataset(Dataset):
"attention_mask": chosen_token['attention_mask']
})
reject = prompt + data['rejected'] + "<|endoftext|>"
reject = prompt + data['rejected'] + self.end_token
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject.append({
"input_ids": reject_token['input_ids'],
"attention_mask": reject_token['attention_mask']
})
def __len__(self):
length = len(self.chosen)
return length
def __getitem__(self, idx):
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
"input_ids"], self.reject[idx]["attention_mask"]
# Anthropic/hh-rlhf
class HhRlhfDataset(Dataset):
"""
Dataset for reward model
Args:
dataset: dataset for reward model
tokenizer: tokenizer for reward model
max_length: max length of input
special_token: special token at the end of sentence
"""
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
self.chosen = []
self.reject = []
if special_token is None:
self.end_token = tokenizer.eos_token
else:
self.end_token = special_token
for data in tqdm(dataset, disable=not is_rank_0()):
chosen = data['chosen'] + self.end_token
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen.append({
"input_ids": chosen_token['input_ids'],
"attention_mask": chosen_token['attention_mask']
})
reject = data['rejected'] + self.end_token
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",

View File

@ -1,4 +1,4 @@
from .base import Actor, Critic, RewardModel
from .loss import PairWiseLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
from .loss import PolicyLoss, PPOPtxActorLoss, ValueLoss, LogSigLoss, LogExpLoss
__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'PairWiseLoss']
__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss']

View File

@ -33,4 +33,5 @@ class BLOOMRM(RewardModel):
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.hidden_size, 1)
value_head.weight.data.normal_(mean=0.0, std=1/(model.config.hidden_size + 1))
super().__init__(model, value_head, lora_rank, lora_train_bias)

View File

@ -35,4 +35,5 @@ class GPTRM(RewardModel):
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.n_embd, 1)
value_head.weight.data.normal_(mean=0.0, std=1/(model.config.n_embd + 1))
super().__init__(model, value_head, lora_rank, lora_train_bias)

View File

@ -93,13 +93,23 @@ class PPOPtxActorLoss(nn.Module):
return policy_loss + self.pretrain_coef * lm_loss
class PairWiseLoss(nn.Module):
class LogSigLoss(nn.Module):
"""
Pairwise Loss for Reward Model
Details: https://arxiv.org/abs/2203.02155
"""
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
probs = torch.sigmoid(chosen_reward - reject_reward)
log_probs = torch.log(probs)
loss = -log_probs.mean()
return loss
class LogExpLoss(nn.Module):
"""
Pairwise Loss for Reward Model
Details: https://arxiv.org/abs/2204.05862
"""
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
return loss

View File

@ -34,4 +34,5 @@ class OPTRM(RewardModel):
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
value_head.weight.data.normal_(mean=0.0, std=1/(model.config.word_embed_proj_dim + 1))
super().__init__(model, value_head, lora_rank, lora_train_bias)

View File

@ -1,13 +1,12 @@
from abc import ABC
import pandas as pd
import loralib as lora
import torch
from chatgpt.dataset import RewardDataset
from chatgpt.models.loss import PairWiseLoss
from torch.optim import Adam, Optimizer
from torch.utils.data import DataLoader
from datetime import datetime
from torch.optim import Optimizer, lr_scheduler
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from .strategies import Strategy
from .utils import is_rank_0
@ -20,11 +19,12 @@ class RewardModelTrainer(ABC):
model (torch.nn.Module): the model to train
strategy (Strategy): the strategy to use for training
optim(Optimizer): the optimizer to use for training
train_dataset (RewardDataset): the dataset to use for training
eval_dataset (RewardDataset): the dataset to use for evaluation
loss_fn (callable): the loss function to use for training
train_dataset (Dataset): the dataset to use for training
valid_dataset (Dataset): the dataset to use for validation
eval_dataset (Dataset): the dataset to use for evaluation
batch_size (int, defaults to 1): the batch size while training
max_epochs (int, defaults to 2): the number of epochs to train
optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
"""
def __init__(
@ -32,24 +32,52 @@ class RewardModelTrainer(ABC):
model,
strategy: Strategy,
optim: Optimizer,
train_dataset: RewardDataset,
eval_dataset: RewardDataset,
loss_fn,
train_dataset: Dataset,
valid_dataset: Dataset,
eval_dataset: Dataset,
batch_size: int = 1,
max_epochs: int = 2,
max_epochs: int = 1,
) -> None:
super().__init__()
self.strategy = strategy
self.epochs = max_epochs
self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)
self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
self.valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)
self.model = strategy.setup_model(model)
if "DDP" in str(self.strategy):
self.model = self.model.module
self.loss_fn = PairWiseLoss()
self.loss_fn = loss_fn
self.optimizer = strategy.setup_optimizer(optim, self.model)
self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__()//100)
def fit(self, use_lora):
def eval_acc(self, dataloader):
dist = 0
on = 0
cnt = 0
self.model.eval()
with torch.no_grad():
for chosen_ids, c_mask, reject_ids, r_mask in dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
for i in range(len(chosen_reward)):
cnt += 1
if chosen_reward[i] > reject_reward[i]:
on += 1
dist += (chosen_reward - reject_reward).mean().item()
dist_mean = dist / len(dataloader)
acc = on / cnt
self.model.train()
return dist_mean, acc
def fit(self):
time = datetime.now()
epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
for epoch in range(self.epochs):
step_bar = tqdm(range(self.train_dataloader.__len__()),
@ -57,37 +85,36 @@ class RewardModelTrainer(ABC):
disable=not is_rank_0())
# train
self.model.train()
cnt = 0
acc = 0
dist = 0
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
chosen_ids = chosen_ids.squeeze(1).cuda()
c_mask = c_mask.squeeze(1).cuda()
reject_ids = reject_ids.squeeze(1).cuda()
r_mask = r_mask.squeeze(1).cuda()
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
loss = self.loss_fn(chosen_reward, reject_reward)
self.strategy.backward(loss, self.model, self.optimizer)
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
cnt += 1
if cnt == 100:
self.scheduler.step()
dist, acc = self.eval_acc(self.valid_dataloader)
cnt = 0
if is_rank_0():
log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc'])
log.to_csv('log_%s.csv' % time, mode='a', header=False, index=False)
step_bar.update()
step_bar.set_postfix({'loss': loss.item()})
step_bar.set_postfix({'dist': dist, 'acc': acc})
# eval
self.model.eval()
with torch.no_grad():
dist = 0
loss_sum = 0
for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
chosen_ids = chosen_ids.squeeze(1).cuda()
c_mask = c_mask.squeeze(1).cuda()
reject_ids = reject_ids.squeeze(1).cuda()
r_mask = r_mask.squeeze(1).cuda()
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
dist += (chosen_reward - reject_reward).mean().item()
loss = self.loss_fn(chosen_reward, reject_reward)
loss_sum += loss.item()
dist_mean = dist / self.eval_dataloader.__len__()
loss_mean = loss_sum / self.eval_dataloader.__len__()
dist, acc = self.eval_acc(self.eval_dataloader)
if is_rank_0():
log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc'])
log.to_csv('log.csv', mode='a', header=False, index=False)
epoch_bar.update()
step_bar.set_postfix({'loss': loss_mean, 'dist_mean': dist_mean})
step_bar.set_postfix({'dist': dist, 'acc': acc})
step_bar.close()

View File

@ -7,26 +7,42 @@ pip install -r requirements.txt
```
## Train the reward model (Stage 2)
We use [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) as dataset to train our reward model. It is a dataset of chosen & rejected response of the same prompt.
You can download the dataset from huggingface automatically.
Use these code to train your reward model.
```shell
# Naive reward model training
python train_reward_model.py --pretrain <your model path> --model <your model type> --strategy naive
# Take naive reward model training with opt-350m as example
python train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy naive
# use colossalai_zero2
torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain <your model path> --model <your model type> --strategy colossalai_zero2
torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2
```
### Features and tricks in RM training
- We support [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf)and[rm-static](https://huggingface.co/datasets/Dahoas/rm-static) datasets.
- We support 2 kinds of loss_function named 'log_sig'(used by OpenAI) and 'log_exp'(used by Anthropic).
- We change the loss to valid_acc and pair_dist to monitor progress during training.
- We add special token to the end of the sequence to get better result.
- We use cosine-reducing lr-scheduler for RM training.
- We set value_head as 1 liner layer and initialize the weight of value_head using N(01/(d_model + 1)) distribution.
- We train a Bloom-560m reward model for 1 epoch and find the test acc of the model achieve the performance mentions in [Anthropics paper](https://arxiv.org/abs/2112.00861).
### Experiment result
Model performance in [Anthropics paper](https://arxiv.org/abs/2112.00861):
<div align=center> <img width="512" alt="image" src="https://user-images.githubusercontent.com/70618399/225263321-8d64c3a8-6877-4cc8-9b61-0e1c52d3d94f.png">
<div align=left>Our training & test result of bloom-560m for 1 epoch:
<div align=center> <img width="512" alt="image" src="https://user-images.githubusercontent.com/70618399/225262950-a7f0a686-25de-44ec-98f2-11b83ea86674.png">
<div align=left>
## Train with dummy prompt data (Stage 3)
This script supports 3 strategies:
This script supports 4 kinds of strategies:
- naive
- ddp
- colossalai
- colossalai_zero2
- colossalai_gemini
It uses random generated prompt data.
@ -53,7 +69,7 @@ We use [awesome-chatgpt-prompts](https://huggingface.co/datasets/fka/awesome-cha
You should download `prompts.csv` first.
This script also supports 3 strategies.
This script also supports 4 strategies.
```shell
# display cli help
@ -75,6 +91,9 @@ python inference.py --model_path <your actor model path> --model <your model typ
python inference.py --model_path ./actor_checkpoint_prompts.pt --pretrain bigscience/bloom-560m --model bloom
```
## Attention
The examples is just a demo for testing our progress of RM and PPO training.
#### data
- [x] [rm-static](https://huggingface.co/datasets/Dahoas/rm-static)

View File

@ -69,3 +69,23 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \
python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'gpt2' --model gpt2
rm -rf ${BASE}/actor_checkpoint_prompts.pt
# train rm
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'facebook/opt-350m' --model 'opt' \
--strategy colossalai_zero2 --loss_fn 'log_sig'\
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
--test True --lora_rank 4
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'gpt2' --model 'gpt2' \
--strategy colossalai_gemini --loss_fn 'log_exp'\
--dataset 'Dahoas/rm-static' --test True --lora_rank 4
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'bigscience/bloom-560m' --model 'bloom' \
--strategy colossalai_zero2 --loss_fn 'log_sig'\
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
--test True --lora_rank 4
rm -rf ${BASE}/rm_ckpt.pt

View File

@ -2,7 +2,8 @@ import argparse
import loralib as lora
import torch
from chatgpt.dataset import RewardDataset
from chatgpt.dataset import HhRlhfDataset, RmStaticDataset
from chatgpt.models import LogSigLoss, LogExpLoss
from chatgpt.models.base import RewardModel
from chatgpt.models.bloom import BLOOMRM
from chatgpt.models.gpt import GPTRM
@ -10,13 +11,13 @@ from chatgpt.models.opt import OPTRM
from chatgpt.trainer import RewardModelTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from datasets import load_dataset
from random import randint
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.nn.optimizer import HybridAdam
def train(args):
# configure strategy
if args.strategy == 'naive':
@ -33,57 +34,85 @@ def train(args):
# configure model
with strategy.model_init_context():
if args.model == 'bloom':
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
elif args.model == 'opt':
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
elif args.model == 'gpt2':
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
else:
raise ValueError(f'Unsupported model "{args.model}"')
if args.model_path is not None:
state_dict = torch.load(args.model_path)
model.load_state_dict(state_dict)
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
else:
raise ValueError(f'Unsupported model "{args.model}"')
tokenizer.pad_token = tokenizer.eos_token
max_len = 512
max_len = args.max_len
# configure optimizer
if args.strategy.startswith('colossalai'):
optim = HybridAdam(model.parameters(), lr=5e-5)
optim = HybridAdam(model.parameters(), lr=1.5e-5)
else:
optim = Adam(model.parameters(), lr=5e-5)
optim = Adam(model.parameters(), lr=1.5e-5)
# configure loss function
if args.loss_fn == 'log_sig':
loss_fn = LogSigLoss()
elif args.loss_fn == 'log_exp':
loss_fn = LogExpLoss()
else:
raise ValueError(f'Unsupported loss function "{args.loss_fn}"')
# prepare for data and dataset
data = load_dataset(args.dataset)
train_data = data["train"]
eval_data = data['test']
train_dataset = RewardDataset(train_data, tokenizer, max_len)
eval_dataset = RewardDataset(eval_data, tokenizer, max_len)
if args.subset is not None:
data = load_dataset(args.dataset, data_dir=args.subset)
else:
data = load_dataset(args.dataset)
if args.test:
train_data = data['train'].select(range(100))
eval_data = data['test'].select(range(10))
else:
train_data = data['train']
eval_data = data['test']
valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data)//10)))
if args.dataset == 'Dahoas/rm-static':
train_dataset = RmStaticDataset(train_data, tokenizer, max_len)
valid_dataset = RmStaticDataset(valid_data, tokenizer, max_len)
eval_dataset = RmStaticDataset(eval_data, tokenizer, max_len)
elif args.dataset == 'Anthropic/hh-rlhf':
train_dataset = HhRlhfDataset(train_data, tokenizer, max_len)
valid_dataset = HhRlhfDataset(valid_data, tokenizer, max_len)
eval_dataset = HhRlhfDataset(eval_data, tokenizer, max_len)
else:
raise ValueError(f'Unsupported dataset "{args.dataset}"')
trainer = RewardModelTrainer(model=model,
strategy=strategy,
optim=optim,
loss_fn = loss_fn,
train_dataset=train_dataset,
valid_dataset=valid_dataset,
eval_dataset=eval_dataset,
batch_size=args.batch_size,
max_epochs=args.max_epochs)
trainer.fit(use_lora=args.lora_rank)
trainer.fit()
# save model checkpoint after fitting on only rank0
strategy.save_model(model, 'rm_checkpoint.pt', only_rank0=True)
strategy.save_model(trainer.model, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
strategy.save_optimizer(optim, 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False)
if args.need_optim_ckpt:
strategy.save_optimizer(trainer.optimizer, 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
@ -92,10 +121,18 @@ if __name__ == '__main__':
default='naive')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
parser.add_argument('--model_path', type=str, default=None)
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--dataset', type=str,
choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
default='Dahoas/rm-static')
parser.add_argument('--subset', type=str, default=None)
parser.add_argument('--save_path', type=str, default='rm_ckpt.pt')
parser.add_argument('--max_epochs', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--max_len', type=int, default=512)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp'])
parser.add_argument('--test', type=bool, default=False)
args = parser.parse_args()
train(args)

View File

@ -1,20 +1,8 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
| tail -n +2 \
| nl -v 0 \
| tee /dev/tty \
| sort -g -k 2 \
| awk '{print $1}' \
| head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 1
set_n_least_used_CUDA_VISIBLE_DEVICES 2
# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain 'bigscience/bloomz-560m' --model 'bloom' --strategy colossalai_zero2
torchrun --standalone --nproc_per_node=2 train_reward_model.py --model 'gpt2' --strategy colossalai_zero2
# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2
python train_reward_model.py --pretrain '/home/lczht/data2/bloom-560m' \
--model 'bloom' \
--strategy naive \
--loss_fn 'log_exp'\
--save_path 'rmstatic.pt' \
--test True