Mirror of https://github.com/hpcaitech/ColossalAI
Commit c9e27f0d1b (parent 82149e9d1b)
@@ -23,7 +23,7 @@ class RewardModel(LoRAModule):
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none') -> None:
         super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
-        self.body = model
+        self.model = model
         if value_head is not None:
             if value_head.out_features != 1:
                 raise ValueError("The value head of reward model's output dim should be 1!")
@@ -34,7 +34,7 @@ class RewardModel(LoRAModule):
         self.convert_to_lora()

     def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        outputs = self.body(sequences, attention_mask=attention_mask)
+        outputs = self.model(sequences, attention_mask=attention_mask)
         last_hidden_states = outputs['last_hidden_state']
         values = self.value_head(last_hidden_states)[:, :-1]
         value = values.mean(dim=1).squeeze(1)    # ensure shape is (B)
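As a side note on what this hunk computes: the body (now self.model) returns per-token hidden states, a one-output value head scores each token, and the mean over tokens gives one scalar reward per sequence. Below is a minimal, self-contained sketch of that pattern; the tiny GPT-2 config and class name are illustrative assumptions, not the actual ColossalAI implementation.

# Minimal sketch (assumed shapes): transformer body -> last_hidden_state -> value head -> one scalar per sequence.
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2Model

class TinyRewardModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = GPT2Model(GPT2Config(n_layer=2, n_embd=128, n_head=4))  # stands in for the renamed body
        self.value_head = nn.Linear(128, 1)                                  # out_features must be 1

    def forward(self, sequences, attention_mask=None):
        outputs = self.model(sequences, attention_mask=attention_mask)
        values = self.value_head(outputs['last_hidden_state'])[:, :-1]       # (B, T-1, 1), drop last position
        return values.mean(dim=1).squeeze(1)                                 # (B,) one reward per sequence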
@@ -44,6 +44,8 @@ class RewardModelTrainer(ABC):
         self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)

         self.model = strategy.setup_model(model)
+        if "DDP" in str(self.strategy):
+            self.model = self.model.module
         self.loss_fn = PairWiseLoss()
         self.optimizer = strategy.setup_optimizer(optim, self.model)

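The two added lines unwrap the network when the strategy has wrapped it in DistributedDataParallel, so later attribute access (such as the LoRA call in the next hunk) reaches the inner module. A small sketch of the standard PyTorch unwrapping pattern, assuming a plain DDP wrapper rather than ColossalAI's strategy object:

# Sketch of the DDP-unwrapping pattern: DistributedDataParallel exposes the wrapped network via .module.
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def unwrap(model: nn.Module) -> nn.Module:
    # Attribute access like model.model only works on the inner module, not the DDP wrapper.
    return model.module if isinstance(model, DDP) else model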
@@ -56,7 +58,7 @@ class RewardModelTrainer(ABC):
             # train
             if use_lora > 0:
                 print("Using Lora")
-                lora.mark_only_lora_as_trainable(self.model.body)
+                lora.mark_only_lora_as_trainable(self.model.model)

             else:
                 self.model.train()
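With the attribute renamed, the LoRA call now targets self.model.model, the transformer body inside the reward model. A rough sketch of what mark_only_lora_as_trainable does, assuming Microsoft's loralib package (the helper name below is made up for illustration): it freezes every parameter whose name does not contain 'lora_'.

# Rough sketch, assuming the loralib package: freeze everything except the LoRA matrices.
import loralib as lora
import torch.nn as nn

def trainable_parameter_names(model: nn.Module):
    lora.mark_only_lora_as_trainable(model, bias='none')   # requires_grad=False unless 'lora_' is in the name
    return [n for n, p in model.named_parameters() if p.requires_grad]

net = nn.Sequential(lora.Linear(16, 16, r=4), nn.Linear(16, 1))
print(trainable_parameter_names(net))   # only the lora_A / lora_B matrices of the first layer stay trainable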
@@ -61,8 +61,8 @@ def train(args):

     # prepare for data and dataset
     data = load_dataset(args.dataset)
-    train_data = data["train"].select(range(100))
-    eval_data = data['test'].select(range(5))
+    train_data = data["train"]
+    eval_data = data['test']
     train_dataset = RewardDataset(train_data, tokenizer, max_len)
     eval_dataset = RewardDataset(eval_data, tokenizer, max_len)

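This hunk drops the .select(range(...)) debug subsets so training and evaluation see the full Dahoas/rm-static splits. A short sketch of the Hugging Face datasets calls involved, with the old subsetting shown alongside for smoke tests:

# Sketch: full splits as in the new code, plus the small debug subset that the diff removes.
from datasets import load_dataset

data = load_dataset('Dahoas/rm-static')
train_data = data['train']                       # full split, new behaviour
eval_data = data['test']
debug_train = data['train'].select(range(100))   # old behaviour: first 100 examples only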
@@ -93,7 +93,7 @@ if __name__ == '__main__':
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
     parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
-    parser.add_argument('--max_epochs', type=int, default=10)
+    parser.add_argument('--max_epochs', type=int, default=1)
     parser.add_argument('--batch_size', type=int, default=4)
     parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
     args = parser.parse_args()
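Lowering the --max_epochs default to 1 means a flag-free run now does a single epoch. A minimal argparse sketch of the changed defaults (only the options from this hunk; parse_args([]) is used just to show the parsed values):

# Minimal sketch of the parsed defaults after this change.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
parser.add_argument('--max_epochs', type=int, default=1)   # was 10 before this commit
parser.add_argument('--batch_size', type=int, default=4)
args = parser.parse_args([])
print(args.max_epochs)   # -> 1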