mirror of https://github.com/hpcaitech/ColossalAI
[Coati] Train DPO using PP (#6054)
* update dpo * remove unsupport plugin * update msg * update dpo * remove unsupport plugin * update msg * update template * update dataset * add pp for dpo * update dpo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add dpo fn * update dpo * update dpo * update dpo * update dpo * minor update * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update loss * update help * polish code --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/6118/merge
parent
dc2cdaf3e8
commit
4c8e85ee0d
|
@ -153,10 +153,11 @@ class DpoLoss(nn.Module):
|
|||
else:
|
||||
# If no reference model is provided
|
||||
ref_logratios = 0.0
|
||||
|
||||
pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
|
||||
logits = pi_logratios - ref_logratios - self.gamma / self.beta
|
||||
losses = -torch.nn.functional.logsigmoid(self.beta * logits)
|
||||
|
||||
loss = losses.mean()
|
||||
# Calculate rewards for logging
|
||||
if logprob_ref_chosen is not None:
|
||||
chosen_rewards = self.beta * (logprob_actor_chosen.sum(-1) - logprob_ref_chosen.sum(-1)).detach()
|
||||
|
@ -167,7 +168,7 @@ class DpoLoss(nn.Module):
|
|||
else:
|
||||
rejected_rewards = self.beta * logprob_actor_reject.sum(-1).detach()
|
||||
|
||||
return losses, chosen_rewards, rejected_rewards
|
||||
return loss, chosen_rewards, rejected_rewards
|
||||
|
||||
|
||||
class LogSigLoss(nn.Module):
|
||||
|
|
|
@ -50,8 +50,8 @@ def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.
|
|||
torch.Tensor: The log probabilities corresponding to the labels.
|
||||
"""
|
||||
log_probs = F.log_softmax(logits, dim=-1)
|
||||
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
|
||||
return log_probs_labels.squeeze(-1)
|
||||
per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
|
||||
return per_label_logps.squeeze(-1)
|
||||
|
||||
|
||||
def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
|
||||
|
|
|
@ -6,6 +6,7 @@ import os
|
|||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coati.models.loss import DpoLoss
|
||||
from coati.models.utils import calc_masked_log_probs
|
||||
from coati.trainer.utils import all_reduce_mean
|
||||
|
@ -13,10 +14,11 @@ from coati.utils import AccumulativeMeanMeter, save_checkpoint
|
|||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import trange
|
||||
from tqdm import tqdm, trange
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from colossalai.booster import Booster, Plugin
|
||||
from colossalai.booster.plugin import HybridParallelPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
@ -96,18 +98,25 @@ class DPOTrainer(SLTrainer):
|
|||
self.train_dataloader = train_preference_dataloader
|
||||
self.eval_dataloader = eval_preference_dataloader
|
||||
self.writer = None
|
||||
if use_wandb and is_rank_0():
|
||||
|
||||
init_criterion = (
|
||||
dist.get_rank() == dist.get_world_size() - 1
|
||||
if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1
|
||||
else is_rank_0()
|
||||
)
|
||||
|
||||
if use_wandb and init_criterion:
|
||||
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
|
||||
import wandb
|
||||
|
||||
self.wandb_run = wandb.init(project="Coati-dpo", sync_tensorboard=True)
|
||||
if log_dir is not None and is_rank_0():
|
||||
if log_dir is not None and init_criterion:
|
||||
import os
|
||||
import time
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
log_dir = os.path.join(log_dir, "dpo")
|
||||
log_dir = os.path.join(log_dir, "DPO")
|
||||
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
|
||||
self.writer = SummaryWriter(log_dir=log_dir)
|
||||
|
||||
|
@ -117,166 +126,147 @@ class DPOTrainer(SLTrainer):
|
|||
epoch int: the number of current epoch
|
||||
"""
|
||||
self.model.train()
|
||||
self.accumulative_meter.reset()
|
||||
step_bar = trange(
|
||||
len(self.train_dataloader) // self.accumulation_steps,
|
||||
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||
disable=not is_rank_0(),
|
||||
)
|
||||
for i, batch in enumerate(self.train_dataloader):
|
||||
batch = to_device(batch, self.device)
|
||||
(
|
||||
chosen_input_ids,
|
||||
chosen_attention_mask,
|
||||
chosen_loss_mask,
|
||||
reject_input_ids,
|
||||
reject_attention_mask,
|
||||
reject_loss_mask,
|
||||
) = (
|
||||
batch["chosen_input_ids"],
|
||||
batch["chosen_attention_mask"],
|
||||
batch["chosen_loss_mask"],
|
||||
batch["reject_input_ids"],
|
||||
batch["reject_attention_mask"],
|
||||
batch["reject_loss_mask"],
|
||||
if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
|
||||
step_bar = tqdm(
|
||||
range(len(self.train_dataloader)),
|
||||
desc="Step",
|
||||
disable=not (dist.get_rank() == dist.get_world_size() - 1),
|
||||
)
|
||||
if not self.apply_loss_mask:
|
||||
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
|
||||
reject_loss_mask = reject_loss_mask.fill_(1.0)
|
||||
for i, batch in enumerate(self.train_dataloader):
|
||||
batch = to_device(batch, self.device)
|
||||
(
|
||||
chosen_input_ids,
|
||||
chosen_attention_mask,
|
||||
chosen_loss_mask,
|
||||
reject_input_ids,
|
||||
reject_attention_mask,
|
||||
reject_loss_mask,
|
||||
) = (
|
||||
batch["chosen_input_ids"],
|
||||
batch["chosen_attention_mask"],
|
||||
batch["chosen_loss_mask"],
|
||||
batch["reject_input_ids"],
|
||||
batch["reject_attention_mask"],
|
||||
batch["reject_loss_mask"],
|
||||
)
|
||||
batch_size = chosen_input_ids.size()[0]
|
||||
# Calculate logits from reference model.
|
||||
if self.ref_model is not None:
|
||||
self.ref_model.eval()
|
||||
with torch.no_grad():
|
||||
ref_all_logits = self.ref_model(
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"]
|
||||
ref_chosen_logits = ref_all_logits[:batch_size]
|
||||
ref_reject_logits = ref_all_logits[batch_size:]
|
||||
logprob_ref_chosen = calc_masked_log_probs(
|
||||
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
logprob_ref_reject = calc_masked_log_probs(
|
||||
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
else:
|
||||
logprob_ref_chosen = None
|
||||
logprob_ref_reject = None
|
||||
|
||||
batch_size = chosen_input_ids.size()[0]
|
||||
# Merge chosen and reject
|
||||
inputs_ids = torch.stack([item for tup in zip(chosen_input_ids, reject_input_ids) for item in tup])
|
||||
attention_mask = torch.stack(
|
||||
[item for tup in zip(chosen_attention_mask, reject_attention_mask) for item in tup]
|
||||
)
|
||||
loss_mask = torch.stack([item for tup in zip(chosen_loss_mask, reject_loss_mask) for item in tup])
|
||||
logprob_ref = torch.stack([item for tup in zip(logprob_ref_chosen, logprob_ref_reject) for item in tup])
|
||||
|
||||
actor_all_logits = self.model(
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"]
|
||||
actor_chosen_logits = actor_all_logits[:batch_size]
|
||||
actor_reject_logits = actor_all_logits[batch_size:]
|
||||
logprob_actor_chosen = calc_masked_log_probs(
|
||||
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
data_iter = iter(
|
||||
[
|
||||
{
|
||||
"input_ids": inputs_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"loss_mask": loss_mask,
|
||||
"logprob_ref": logprob_ref,
|
||||
}
|
||||
]
|
||||
)
|
||||
rewards = []
|
||||
|
||||
logprob_actor_reject = calc_masked_log_probs(
|
||||
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
|
||||
if self.ref_model is not None:
|
||||
self.ref_model.eval()
|
||||
with torch.no_grad():
|
||||
ref_all_logits = self.ref_model(
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"]
|
||||
ref_chosen_logits = ref_all_logits[:batch_size]
|
||||
ref_reject_logits = ref_all_logits[batch_size:]
|
||||
logprob_ref_chosen = calc_masked_log_probs(
|
||||
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||
def _criterion(outputs, inputs):
|
||||
loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
|
||||
calc_masked_log_probs(
|
||||
outputs["logits"][0::2],
|
||||
inputs["input_ids"][0::2],
|
||||
inputs["loss_mask"][0::2][:, 1:],
|
||||
self.length_normalization,
|
||||
),
|
||||
calc_masked_log_probs(
|
||||
outputs["logits"][1::2],
|
||||
inputs["input_ids"][1::2],
|
||||
inputs["loss_mask"][1::2][:, 1:],
|
||||
self.length_normalization,
|
||||
),
|
||||
inputs["logprob_ref"][0::2] if inputs["logprob_ref"] is not None else None,
|
||||
inputs["logprob_ref"][1::2] if inputs["logprob_ref"] is not None else None,
|
||||
inputs["loss_mask"][0::2][:, 1:],
|
||||
inputs["loss_mask"][1::2][:, 1:],
|
||||
)
|
||||
logprob_ref_reject = calc_masked_log_probs(
|
||||
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
else:
|
||||
logprob_ref_chosen = None
|
||||
logprob_ref_reject = None
|
||||
rewards.append(chosen_rewards)
|
||||
rewards.append(rejected_rewards)
|
||||
return loss
|
||||
|
||||
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
|
||||
logprob_actor_chosen,
|
||||
logprob_actor_reject,
|
||||
logprob_ref_chosen if logprob_ref_chosen is not None else None,
|
||||
logprob_ref_reject if logprob_ref_reject is not None else None,
|
||||
chosen_loss_mask[:, 1:],
|
||||
reject_loss_mask[:, 1:],
|
||||
)
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
|
||||
outputs = self.booster.execute_pipeline(
|
||||
data_iter,
|
||||
self.model,
|
||||
criterion=_criterion,
|
||||
optimizer=self.optimizer,
|
||||
return_loss=True,
|
||||
)
|
||||
loss = outputs["loss"]
|
||||
if self.booster.plugin.stage_manager.is_last_stage():
|
||||
chosen_rewards, rejected_rewards = rewards[0], rewards[1]
|
||||
global_loss = all_reduce_mean(loss, self.plugin)
|
||||
if dist.get_rank() == dist.get_world_size() - 1:
|
||||
step_bar.set_postfix(
|
||||
{
|
||||
"train/loss": global_loss.item(),
|
||||
"train/lr": self.actor_scheduler.get_last_lr()[0],
|
||||
"train/chosen_rewards": chosen_rewards.to(torch.float16).mean().item(),
|
||||
"train/rejected_rewards": rejected_rewards.to(torch.float16).mean().item(),
|
||||
}
|
||||
)
|
||||
step_bar.update()
|
||||
self.accumulative_meter.add("loss", global_loss.item())
|
||||
self.accumulative_meter.add("chosen_rewards", chosen_rewards.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add(
|
||||
"rejected_rewards", rejected_rewards.to(torch.float16).mean().item()
|
||||
)
|
||||
if self.writer is not None:
|
||||
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), i)
|
||||
self.writer.add_scalar(
|
||||
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), i
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/rejected_rewards",
|
||||
self.accumulative_meter.get("rejected_rewards"),
|
||||
i,
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/margin",
|
||||
self.accumulative_meter.get("chosen_rewards")
|
||||
- self.accumulative_meter.get("rejected_rewards"),
|
||||
i,
|
||||
)
|
||||
|
||||
# DPO Loss
|
||||
loss = losses.mean()
|
||||
|
||||
self.booster.backward(loss=loss, optimizer=self.optimizer)
|
||||
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.actor_scheduler.step()
|
||||
|
||||
# sync
|
||||
loss_mean = all_reduce_mean(tensor=loss)
|
||||
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
|
||||
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
|
||||
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
|
||||
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
||||
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
|
||||
|
||||
if i % self.accumulation_steps == self.accumulation_steps - 1:
|
||||
self.num_train_step += 1
|
||||
step_bar.update()
|
||||
# logging
|
||||
if self.writer and is_rank_0():
|
||||
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
|
||||
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
|
||||
self.writer.add_scalar(
|
||||
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/rejected_rewards",
|
||||
self.accumulative_meter.get("rejected_rewards"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/margin",
|
||||
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/accuracy",
|
||||
self.accumulative_meter.get("accuracy"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.accumulative_meter.reset()
|
||||
|
||||
if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
|
||||
# save checkpoint
|
||||
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||
save_checkpoint(
|
||||
save_dir=self.save_dir,
|
||||
booster=self.booster,
|
||||
model=self.model,
|
||||
optimizer=self.optimizer,
|
||||
lr_scheduler=self.actor_scheduler,
|
||||
epoch=epoch,
|
||||
step=i + 1,
|
||||
batch_size=batch_size,
|
||||
coordinator=self.coordinator,
|
||||
)
|
||||
self.coordinator.print_on_master(
|
||||
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
|
||||
)
|
||||
|
||||
step_bar.close()
|
||||
|
||||
def _eval(self, epoch: int):
|
||||
"""
|
||||
Args:
|
||||
epoch int: the number of current epoch
|
||||
"""
|
||||
if self.eval_dataloader is None:
|
||||
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
|
||||
return
|
||||
self.model.eval()
|
||||
self.ref_model.eval()
|
||||
self.coordinator.print_on_master("\nStart evaluation...")
|
||||
|
||||
step_bar = trange(
|
||||
len(self.eval_dataloader),
|
||||
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||
disable=not is_rank_0(),
|
||||
)
|
||||
|
||||
self.accumulative_meter.reset()
|
||||
|
||||
with torch.no_grad():
|
||||
for i, batch in enumerate(self.eval_dataloader):
|
||||
else:
|
||||
self.accumulative_meter.reset()
|
||||
step_bar = trange(
|
||||
len(self.train_dataloader) // self.accumulation_steps,
|
||||
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||
disable=not is_rank_0(),
|
||||
)
|
||||
for i, batch in enumerate(self.train_dataloader):
|
||||
batch = to_device(batch, self.device)
|
||||
(
|
||||
chosen_input_ids,
|
||||
|
@ -300,12 +290,11 @@ class DPOTrainer(SLTrainer):
|
|||
batch_size = chosen_input_ids.size()[0]
|
||||
|
||||
actor_all_logits = self.model(
|
||||
torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"]
|
||||
actor_chosen_logits = actor_all_logits[:batch_size]
|
||||
actor_reject_logits = actor_all_logits[batch_size:]
|
||||
|
||||
logprob_actor_chosen = calc_masked_log_probs(
|
||||
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
|
@ -314,22 +303,26 @@ class DPOTrainer(SLTrainer):
|
|||
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
|
||||
self.ref_model.eval()
|
||||
if self.ref_model is not None:
|
||||
self.ref_model.eval()
|
||||
with torch.no_grad():
|
||||
ref_all_logits = self.ref_model(
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"]
|
||||
ref_chosen_logits = ref_all_logits[:batch_size]
|
||||
ref_reject_logits = ref_all_logits[batch_size:]
|
||||
logprob_ref_chosen = calc_masked_log_probs(
|
||||
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
logprob_ref_reject = calc_masked_log_probs(
|
||||
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
else:
|
||||
logprob_ref_chosen = None
|
||||
logprob_ref_reject = None
|
||||
|
||||
ref_all_logits = self.ref_model(
|
||||
torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"]
|
||||
ref_chosen_logits = ref_all_logits[:batch_size]
|
||||
ref_reject_logits = ref_all_logits[batch_size:]
|
||||
logprob_ref_chosen = calc_masked_log_probs(
|
||||
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
logprob_ref_reject = calc_masked_log_probs(
|
||||
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
|
||||
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
|
||||
loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
|
||||
logprob_actor_chosen,
|
||||
logprob_actor_reject,
|
||||
logprob_ref_chosen if logprob_ref_chosen is not None else None,
|
||||
|
@ -338,7 +331,9 @@ class DPOTrainer(SLTrainer):
|
|||
reject_loss_mask[:, 1:],
|
||||
)
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
|
||||
loss = losses.mean()
|
||||
|
||||
self.booster.backward(loss=loss, optimizer=self.optimizer)
|
||||
# sync
|
||||
loss_mean = all_reduce_mean(tensor=loss)
|
||||
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
|
||||
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
|
||||
|
@ -347,16 +342,301 @@ class DPOTrainer(SLTrainer):
|
|||
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
||||
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
|
||||
self.accumulative_meter.add(
|
||||
"margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
|
||||
)
|
||||
step_bar.update()
|
||||
|
||||
msg = "Evaluation Result:\n"
|
||||
for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
|
||||
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
||||
self.coordinator.print_on_master(msg)
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
|
||||
f.write(msg)
|
||||
if (i + 1) % self.accumulation_steps == 0:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.actor_scheduler.step()
|
||||
|
||||
step_bar.set_postfix(
|
||||
{
|
||||
"train/loss": self.accumulative_meter.get("loss"),
|
||||
"train/chosen_rewards": self.accumulative_meter.get("chosen_rewards"),
|
||||
"train/rejected_rewards": self.accumulative_meter.get("rejected_rewards"),
|
||||
"train/accuracy": self.accumulative_meter.get("accuracy"),
|
||||
}
|
||||
)
|
||||
step_bar.update()
|
||||
if self.writer and is_rank_0():
|
||||
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
|
||||
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
|
||||
self.writer.add_scalar(
|
||||
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/rejected_rewards",
|
||||
self.accumulative_meter.get("rejected_rewards"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/margin",
|
||||
self.accumulative_meter.get("chosen_rewards")
|
||||
- self.accumulative_meter.get("rejected_rewards"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/accuracy",
|
||||
self.accumulative_meter.get("accuracy"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.num_train_step += 1
|
||||
self.accumulative_meter.reset()
|
||||
|
||||
if self.save_dir is not None and self.num_train_step > 0 and self.num_train_step % self.save_interval == 0:
|
||||
# save checkpoint
|
||||
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||
save_checkpoint(
|
||||
save_dir=self.save_dir,
|
||||
booster=self.booster,
|
||||
model=self.model,
|
||||
optimizer=self.optimizer,
|
||||
lr_scheduler=self.actor_scheduler,
|
||||
epoch=epoch,
|
||||
step=self.num_train_step,
|
||||
batch_size=batch_size,
|
||||
coordinator=self.coordinator,
|
||||
)
|
||||
self.coordinator.print_on_master(
|
||||
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
|
||||
)
|
||||
|
||||
step_bar.close()
|
||||
|
||||
def _eval(self, epoch: int):
|
||||
"""
|
||||
Args:
|
||||
epoch int: the number of current epoch
|
||||
"""
|
||||
if self.eval_dataloader is None:
|
||||
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
|
||||
return
|
||||
self.model.eval()
|
||||
self.ref_model.eval()
|
||||
self.accumulative_meter.reset()
|
||||
self.coordinator.print_on_master("\nStart evaluation...")
|
||||
|
||||
if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
|
||||
step_bar = tqdm(
|
||||
range(len(self.eval_dataloader)),
|
||||
desc="Step",
|
||||
disable=not (dist.get_rank() == dist.get_world_size() - 1),
|
||||
)
|
||||
with torch.no_grad():
|
||||
for _, batch in enumerate(self.eval_dataloader):
|
||||
batch = to_device(batch, self.device)
|
||||
(
|
||||
chosen_input_ids,
|
||||
chosen_attention_mask,
|
||||
chosen_loss_mask,
|
||||
reject_input_ids,
|
||||
reject_attention_mask,
|
||||
reject_loss_mask,
|
||||
) = (
|
||||
batch["chosen_input_ids"],
|
||||
batch["chosen_attention_mask"],
|
||||
batch["chosen_loss_mask"],
|
||||
batch["reject_input_ids"],
|
||||
batch["reject_attention_mask"],
|
||||
batch["reject_loss_mask"],
|
||||
)
|
||||
batch_size = chosen_input_ids.size()[0]
|
||||
# Calculate logits from reference model.
|
||||
if self.ref_model is not None:
|
||||
self.ref_model.eval()
|
||||
with torch.no_grad():
|
||||
ref_all_logits = self.ref_model(
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"]
|
||||
ref_chosen_logits = ref_all_logits[:batch_size]
|
||||
ref_reject_logits = ref_all_logits[batch_size:]
|
||||
logprob_ref_chosen = calc_masked_log_probs(
|
||||
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
logprob_ref_reject = calc_masked_log_probs(
|
||||
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
else:
|
||||
logprob_ref_chosen = None
|
||||
logprob_ref_reject = None
|
||||
|
||||
# Merge chosen and reject
|
||||
inputs_ids = torch.stack([item for tup in zip(chosen_input_ids, reject_input_ids) for item in tup])
|
||||
attention_mask = torch.stack(
|
||||
[item for tup in zip(chosen_attention_mask, reject_attention_mask) for item in tup]
|
||||
)
|
||||
loss_mask = torch.stack([item for tup in zip(chosen_loss_mask, reject_loss_mask) for item in tup])
|
||||
logprob_ref = torch.stack(
|
||||
[item for tup in zip(logprob_ref_chosen, logprob_ref_reject) for item in tup]
|
||||
)
|
||||
|
||||
data_iter = iter(
|
||||
[
|
||||
{
|
||||
"input_ids": inputs_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"loss_mask": loss_mask,
|
||||
"logprob_ref": logprob_ref,
|
||||
}
|
||||
]
|
||||
)
|
||||
rewards = []
|
||||
|
||||
def _criterion(outputs, inputs):
|
||||
loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
|
||||
calc_masked_log_probs(
|
||||
outputs["logits"][0::2],
|
||||
inputs["input_ids"][0::2],
|
||||
inputs["loss_mask"][0::2][:, 1:],
|
||||
self.length_normalization,
|
||||
),
|
||||
calc_masked_log_probs(
|
||||
outputs["logits"][1::2],
|
||||
inputs["input_ids"][1::2],
|
||||
inputs["loss_mask"][1::2][:, 1:],
|
||||
self.length_normalization,
|
||||
),
|
||||
inputs["logprob_ref"][0::2] if inputs["logprob_ref"] is not None else None,
|
||||
inputs["logprob_ref"][1::2] if inputs["logprob_ref"] is not None else None,
|
||||
inputs["loss_mask"][0::2][:, 1:],
|
||||
inputs["loss_mask"][1::2][:, 1:],
|
||||
)
|
||||
rewards.append(chosen_rewards)
|
||||
rewards.append(rejected_rewards)
|
||||
return loss
|
||||
|
||||
outputs = self.booster.execute_pipeline(
|
||||
data_iter,
|
||||
self.model,
|
||||
criterion=_criterion,
|
||||
optimizer=self.optimizer,
|
||||
return_loss=True,
|
||||
)
|
||||
loss = outputs["loss"]
|
||||
if self.booster.plugin.stage_manager.is_last_stage():
|
||||
chosen_rewards, rejected_rewards = rewards[0], rewards[1]
|
||||
global_loss = all_reduce_mean(loss, self.plugin)
|
||||
chosen_rewards_mean = all_reduce_mean(chosen_rewards, self.plugin)
|
||||
rejected_rewards_mean = all_reduce_mean(rejected_rewards, self.plugin)
|
||||
if dist.get_rank() == dist.get_world_size() - 1:
|
||||
step_bar.set_postfix(
|
||||
{
|
||||
"eval/loss": global_loss.item(),
|
||||
"eval/lr": self.actor_scheduler.get_last_lr()[0],
|
||||
"eval/chosen_rewards": chosen_rewards.to(torch.float16).mean().item(),
|
||||
"eval/rejected_rewards": rejected_rewards.to(torch.float16).mean().item(),
|
||||
}
|
||||
)
|
||||
self.accumulative_meter.add(
|
||||
"chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()
|
||||
)
|
||||
self.accumulative_meter.add(
|
||||
"rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()
|
||||
)
|
||||
self.accumulative_meter.add("loss", global_loss.to(torch.float16).item())
|
||||
step_bar.update()
|
||||
if self.booster.plugin.stage_manager.is_last_stage():
|
||||
msg = "\nEvaluation Result:\n"
|
||||
for tag in ["loss", "chosen_rewards", "rejected_rewards"]:
|
||||
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
||||
if dist.get_rank() == dist.get_world_size() - 1:
|
||||
print(msg)
|
||||
else:
|
||||
step_bar = trange(
|
||||
len(self.eval_dataloader),
|
||||
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||
disable=not is_rank_0(),
|
||||
)
|
||||
with torch.no_grad():
|
||||
for i, batch in enumerate(self.eval_dataloader):
|
||||
batch = to_device(batch, self.device)
|
||||
(
|
||||
chosen_input_ids,
|
||||
chosen_attention_mask,
|
||||
chosen_loss_mask,
|
||||
reject_input_ids,
|
||||
reject_attention_mask,
|
||||
reject_loss_mask,
|
||||
) = (
|
||||
batch["chosen_input_ids"],
|
||||
batch["chosen_attention_mask"],
|
||||
batch["chosen_loss_mask"],
|
||||
batch["reject_input_ids"],
|
||||
batch["reject_attention_mask"],
|
||||
batch["reject_loss_mask"],
|
||||
)
|
||||
if not self.apply_loss_mask:
|
||||
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
|
||||
reject_loss_mask = reject_loss_mask.fill_(1.0)
|
||||
|
||||
batch_size = chosen_input_ids.size()[0]
|
||||
|
||||
actor_all_logits = self.model(
|
||||
torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"]
|
||||
actor_chosen_logits = actor_all_logits[:batch_size]
|
||||
actor_reject_logits = actor_all_logits[batch_size:]
|
||||
|
||||
logprob_actor_chosen = calc_masked_log_probs(
|
||||
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
|
||||
logprob_actor_reject = calc_masked_log_probs(
|
||||
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
ref_all_logits = self.ref_model(
|
||||
torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"]
|
||||
ref_chosen_logits = ref_all_logits[:batch_size]
|
||||
ref_reject_logits = ref_all_logits[batch_size:]
|
||||
logprob_ref_chosen = calc_masked_log_probs(
|
||||
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
logprob_ref_reject = calc_masked_log_probs(
|
||||
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
|
||||
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
|
||||
logprob_actor_chosen,
|
||||
logprob_actor_reject,
|
||||
logprob_ref_chosen if logprob_ref_chosen is not None else None,
|
||||
logprob_ref_reject if logprob_ref_reject is not None else None,
|
||||
chosen_loss_mask[:, 1:],
|
||||
reject_loss_mask[:, 1:],
|
||||
)
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
|
||||
loss = losses.mean()
|
||||
loss_mean = all_reduce_mean(tensor=loss)
|
||||
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
|
||||
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
|
||||
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
|
||||
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add(
|
||||
"rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()
|
||||
)
|
||||
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
||||
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
|
||||
self.accumulative_meter.add(
|
||||
"margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
|
||||
)
|
||||
step_bar.set_postfix(
|
||||
{
|
||||
"eval/loss": self.accumulative_meter.get("loss"),
|
||||
"eval/chosen_rewards": self.accumulative_meter.get("chosen_rewards"),
|
||||
"eval/rejected_rewards": self.accumulative_meter.get("rejected_rewards"),
|
||||
"eval/accuracy": self.accumulative_meter.get("accuracy"),
|
||||
}
|
||||
)
|
||||
step_bar.update()
|
||||
|
||||
msg = "\nEvaluation Result:\n"
|
||||
for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
|
||||
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
||||
self.coordinator.print_on_master(msg)
|
||||
if self.save_dir is not None:
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
|
||||
f.write(msg)
|
||||
step_bar.close()
|
||||
|
|
|
@ -73,8 +73,7 @@ def main():
|
|||
"--conversation_template_config",
|
||||
type=str,
|
||||
default="conversation_template_config",
|
||||
help="Path \
|
||||
to save conversation template config files.",
|
||||
help="Path to save conversation template config files.",
|
||||
)
|
||||
parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
|
||||
parser.add_argument(
|
||||
|
|
|
@ -13,7 +13,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
|
@ -29,8 +29,6 @@ def train(args):
|
|||
# check lora compatibility
|
||||
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
|
||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
|
@ -46,7 +44,7 @@ def train(args):
|
|||
Default torch ddp plugin without any acceleration, for
|
||||
debugging purpose acceleration, for debugging purpose
|
||||
"""
|
||||
plugin = TorchDDPPlugin(find_unused_parameters=True)
|
||||
plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
|
@ -56,14 +54,6 @@ def train(args):
|
|||
enable_gradient_accumulation=True,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
|
@ -92,20 +82,24 @@ def train(args):
|
|||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
microbatch_size=args.microbatch_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
ref_booster = Booster(plugin=plugin)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
# Temp Fix: Disable lazy init due to version conflict
|
||||
# init_ctx = (
|
||||
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
# )
|
||||
ref_plugin = HybridParallelPlugin(
|
||||
tp_size=args.ref_tp,
|
||||
pp_size=1,
|
||||
zero_stage=args.zero_stage,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
|
||||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
)
|
||||
ref_booster = Booster(plugin=ref_plugin)
|
||||
|
||||
init_ctx = nullcontext()
|
||||
with init_ctx:
|
||||
|
@ -130,6 +124,7 @@ def train(args):
|
|||
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||
else:
|
||||
ref_model = None
|
||||
|
||||
if args.lora_config is not None:
|
||||
model = convert_to_lora_module(model, lora_config=lora_config)
|
||||
for name, module in model.named_modules():
|
||||
|
@ -139,7 +134,9 @@ def train(args):
|
|||
disable_dropout(ref_model)
|
||||
|
||||
if args.grad_checkpoint:
|
||||
# Note, for some models, lora may not be compatible with gradient checkpointing
|
||||
# Make sure gradient checkpointing can be activated.
|
||||
model.train()
|
||||
# Note, for some models, lora may not be compatible with gradient checkpointing.
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
|
||||
|
@ -169,7 +166,7 @@ def train(args):
|
|||
adamw_mode=True,
|
||||
)
|
||||
|
||||
# configure dataset
|
||||
# Configure dataset
|
||||
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
||||
mode_map = {"train": "train", "valid": "validation", "test": "test"}
|
||||
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
|
||||
|
@ -213,14 +210,15 @@ def train(args):
|
|||
|
||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
|
||||
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
dataloader=train_dataloader,
|
||||
)
|
||||
if ref_model is not None:
|
||||
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
|
||||
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model)
|
||||
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
|
@ -312,7 +310,7 @@ if __name__ == "__main__":
|
|||
"--plugin",
|
||||
type=str,
|
||||
default="gemini",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
|
||||
choices=["gemini", "zero2", "zero2_cpu", "3d", "ddp"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
|
@ -342,22 +340,35 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||
parser.add_argument("--lr", type=float, default=5e-6)
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--log_dir", default=None, type=str)
|
||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
parser.add_argument(
|
||||
"--microbatch_size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Micro batch size for PP training. To activate PP training for DPO-like algorithm, you must keep size even and the size should be equal or greater than 2.",
|
||||
)
|
||||
# Parameter for reference model
|
||||
parser.add_argument(
|
||||
"--disable_reference_model",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable the reference model (enabled by default)",
|
||||
)
|
||||
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||
parser.add_argument("--lr", type=float, default=5e-6)
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--log_dir", default=None, type=str)
|
||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
parser.add_argument(
|
||||
"--ref_tp",
|
||||
type=int,
|
||||
default=1,
|
||||
help="TP size for reference model; used only when reference model is too large.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# fool proof hyperparameter setup
|
||||
|
|
|
@ -68,7 +68,7 @@ def train(args):
|
|||
Default torch ddp plugin without any acceleration, for
|
||||
debugging purpose acceleration, for debugging purpose
|
||||
"""
|
||||
plugin = TorchDDPPlugin(find_unused_parameters=True if args.grad_checkpoint is False else False)
|
||||
plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
|
|
|
@ -4,7 +4,7 @@ BASE_TEMP_DIR=$BASE_DIR/temp
|
|||
EXAMPLES_DIR=$BASE_DIR/examples
|
||||
TEST_DATA_DIR=$BASE_DIR/tests/test_data
|
||||
DATA_SAVE_PATH=$BASE_TEMP_DIR/tests
|
||||
CONFIG_DIR=$BASE_DIR/config
|
||||
CONFIG_DIR=$BASE_DIR/conversation_template
|
||||
|
||||
# MODELS=("colossal-llama2" "llama2" "mistral" "chatGLM2" "chatGLM3" "deepseek" "Yi" "baichuan") # for local test
|
||||
MODELS=("colossal-llama2" "llama2" "chatGLM2" "chatGLM3" "deepseek" "Yi")
|
||||
|
@ -39,23 +39,23 @@ get_pretrain() {
|
|||
get_conversation_template_config() {
|
||||
local model=$1
|
||||
if [[ $model == "colossal-llama2" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/colossal-llama2.json"
|
||||
echo "$CONFIG_DIR/colossal-llama2.json"
|
||||
elif [[ $model == "llama2" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/llama2.json"
|
||||
echo "$CONFIG_DIR/llama2.json"
|
||||
elif [[ $model == "deepseek" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/deepseek-ai_DeepSeek-V2-Lite.json"
|
||||
echo "$CONFIG_DIR/deepseek-ai_DeepSeek-V2-Lite.json"
|
||||
elif [[ $model == "mistral" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
|
||||
echo "$CONFIG_DIR/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
|
||||
elif [[ $model == "chatGLM2" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/THUDM_chatglm2-6b.json"
|
||||
echo "$CONFIG_DIR/THUDM_chatglm2-6b.json"
|
||||
elif [[ $model == "chatGLM3" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/THUDM_chatglm3-6b.json"
|
||||
echo "$CONFIG_DIR/THUDM_chatglm3-6b.json"
|
||||
elif [[ $model == "phi" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/microsoft_phi-2.json"
|
||||
echo "$CONFIG_DIR/microsoft_phi-2.json"
|
||||
elif [[ $model == "Yi" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/01-ai_Yi-1.5-9B-Chat.json"
|
||||
echo "$CONFIG_DIR/01-ai_Yi-1.5-9B-Chat.json"
|
||||
elif [[ $model == "baichuan" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/baichuan-inc_Baichuan2-13B-Chat.json"
|
||||
echo "$CONFIG_DIR/baichuan-inc_Baichuan2-13B-Chat.json"
|
||||
else
|
||||
echo "Unknown model $model"
|
||||
exit 1
|
||||
|
@ -71,6 +71,7 @@ for model in ${MODELS[@]}; do
|
|||
rm -rf $SAVE_DIR/arrow
|
||||
pretrain=$(get_pretrain $model)
|
||||
conversation_template_config=$(get_conversation_template_config $model)
|
||||
echo $conversation_template_config
|
||||
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type sft --data_input_dirs $TEST_DATA_DIR/sft \
|
||||
--tokenizer_dir $pretrain \
|
||||
--conversation_template_config $conversation_template_config \
|
||||
|
|
|
@ -271,6 +271,7 @@ class LlamaPipelineForwards:
|
|||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
|
|
Loading…
Reference in New Issue