[Coati] Train DPO using PP (#6054)

* update dpo

* remove unsupport plugin

* update msg

* update dpo

* remove unsupport plugin

* update msg

* update template

* update dataset

* add pp for dpo

* update dpo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add dpo fn

* update dpo

* update dpo

* update dpo

* update dpo

* minor update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update loss

* update help

* polish code

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/6118/merge
Tong Li 1 month ago committed by GitHub
parent dc2cdaf3e8
commit 4c8e85ee0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -153,10 +153,11 @@ class DpoLoss(nn.Module):
else: else:
# If no reference model is provided # If no reference model is provided
ref_logratios = 0.0 ref_logratios = 0.0
pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1) pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
logits = pi_logratios - ref_logratios - self.gamma / self.beta logits = pi_logratios - ref_logratios - self.gamma / self.beta
losses = -torch.nn.functional.logsigmoid(self.beta * logits) losses = -torch.nn.functional.logsigmoid(self.beta * logits)
loss = losses.mean()
# Calculate rewards for logging # Calculate rewards for logging
if logprob_ref_chosen is not None: if logprob_ref_chosen is not None:
chosen_rewards = self.beta * (logprob_actor_chosen.sum(-1) - logprob_ref_chosen.sum(-1)).detach() chosen_rewards = self.beta * (logprob_actor_chosen.sum(-1) - logprob_ref_chosen.sum(-1)).detach()
@ -167,7 +168,7 @@ class DpoLoss(nn.Module):
else: else:
rejected_rewards = self.beta * logprob_actor_reject.sum(-1).detach() rejected_rewards = self.beta * logprob_actor_reject.sum(-1).detach()
return losses, chosen_rewards, rejected_rewards return loss, chosen_rewards, rejected_rewards
class LogSigLoss(nn.Module): class LogSigLoss(nn.Module):

@ -50,8 +50,8 @@ def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.
torch.Tensor: The log probabilities corresponding to the labels. torch.Tensor: The log probabilities corresponding to the labels.
""" """
log_probs = F.log_softmax(logits, dim=-1) log_probs = F.log_softmax(logits, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
return log_probs_labels.squeeze(-1) return per_label_logps.squeeze(-1)
def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor: def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:

@ -6,6 +6,7 @@ import os
from typing import Any, Optional from typing import Any, Optional
import torch import torch
import torch.distributed as dist
from coati.models.loss import DpoLoss from coati.models.loss import DpoLoss
from coati.models.utils import calc_masked_log_probs from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean from coati.trainer.utils import all_reduce_mean
@ -13,10 +14,11 @@ from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import trange from tqdm import tqdm, trange
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster, Plugin from colossalai.booster import Booster, Plugin
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
@ -96,18 +98,25 @@ class DPOTrainer(SLTrainer):
self.train_dataloader = train_preference_dataloader self.train_dataloader = train_preference_dataloader
self.eval_dataloader = eval_preference_dataloader self.eval_dataloader = eval_preference_dataloader
self.writer = None self.writer = None
if use_wandb and is_rank_0():
init_criterion = (
dist.get_rank() == dist.get_world_size() - 1
if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1
else is_rank_0()
)
if use_wandb and init_criterion:
assert log_dir is not None, "log_dir must be provided when use_wandb is True" assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb import wandb
self.wandb_run = wandb.init(project="Coati-dpo", sync_tensorboard=True) self.wandb_run = wandb.init(project="Coati-dpo", sync_tensorboard=True)
if log_dir is not None and is_rank_0(): if log_dir is not None and init_criterion:
import os import os
import time import time
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "dpo") log_dir = os.path.join(log_dir, "DPO")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir) self.writer = SummaryWriter(log_dir=log_dir)
@ -117,166 +126,147 @@ class DPOTrainer(SLTrainer):
epoch int: the number of current epoch epoch int: the number of current epoch
""" """
self.model.train() self.model.train()
self.accumulative_meter.reset() if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
step_bar = trange( step_bar = tqdm(
len(self.train_dataloader) // self.accumulation_steps, range(len(self.train_dataloader)),
desc=f"Epoch {epoch + 1}/{self.max_epochs}", desc="Step",
disable=not is_rank_0(), disable=not (dist.get_rank() == dist.get_world_size() - 1),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
chosen_loss_mask,
reject_input_ids,
reject_attention_mask,
reject_loss_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["chosen_loss_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)
batch_size = chosen_input_ids.size()[0]
actor_all_logits = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_actor_reject = calc_masked_log_probs(
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
) )
for i, batch in enumerate(self.train_dataloader):
if self.ref_model is not None: batch = to_device(batch, self.device)
self.ref_model.eval() (
with torch.no_grad(): chosen_input_ids,
ref_all_logits = self.ref_model( chosen_attention_mask,
input_ids=torch.cat([chosen_input_ids, reject_input_ids]), chosen_loss_mask,
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]), reject_input_ids,
)["logits"] reject_attention_mask,
ref_chosen_logits = ref_all_logits[:batch_size] reject_loss_mask,
ref_reject_logits = ref_all_logits[batch_size:] ) = (
logprob_ref_chosen = calc_masked_log_probs( batch["chosen_input_ids"],
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization batch["chosen_attention_mask"],
) batch["chosen_loss_mask"],
logprob_ref_reject = calc_masked_log_probs( batch["reject_input_ids"],
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
batch_size = chosen_input_ids.size()[0]
# Calculate logits from reference model.
if self.ref_model is not None:
self.ref_model.eval()
with torch.no_grad():
ref_all_logits = self.ref_model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
else:
logprob_ref_chosen = None
logprob_ref_reject = None
# Merge chosen and reject
inputs_ids = torch.stack([item for tup in zip(chosen_input_ids, reject_input_ids) for item in tup])
attention_mask = torch.stack(
[item for tup in zip(chosen_attention_mask, reject_attention_mask) for item in tup]
)
loss_mask = torch.stack([item for tup in zip(chosen_loss_mask, reject_loss_mask) for item in tup])
logprob_ref = torch.stack([item for tup in zip(logprob_ref_chosen, logprob_ref_reject) for item in tup])
data_iter = iter(
[
{
"input_ids": inputs_ids,
"attention_mask": attention_mask,
"loss_mask": loss_mask,
"logprob_ref": logprob_ref,
}
]
)
rewards = []
def _criterion(outputs, inputs):
loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
calc_masked_log_probs(
outputs["logits"][0::2],
inputs["input_ids"][0::2],
inputs["loss_mask"][0::2][:, 1:],
self.length_normalization,
),
calc_masked_log_probs(
outputs["logits"][1::2],
inputs["input_ids"][1::2],
inputs["loss_mask"][1::2][:, 1:],
self.length_normalization,
),
inputs["logprob_ref"][0::2] if inputs["logprob_ref"] is not None else None,
inputs["logprob_ref"][1::2] if inputs["logprob_ref"] is not None else None,
inputs["loss_mask"][0::2][:, 1:],
inputs["loss_mask"][1::2][:, 1:],
) )
else: rewards.append(chosen_rewards)
logprob_ref_chosen = None rewards.append(rejected_rewards)
logprob_ref_reject = None return loss
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn( outputs = self.booster.execute_pipeline(
logprob_actor_chosen, data_iter,
logprob_actor_reject, self.model,
logprob_ref_chosen if logprob_ref_chosen is not None else None, criterion=_criterion,
logprob_ref_reject if logprob_ref_reject is not None else None, optimizer=self.optimizer,
chosen_loss_mask[:, 1:], return_loss=True,
reject_loss_mask[:, 1:], )
) loss = outputs["loss"]
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean() if self.booster.plugin.stage_manager.is_last_stage():
chosen_rewards, rejected_rewards = rewards[0], rewards[1]
# DPO Loss global_loss = all_reduce_mean(loss, self.plugin)
loss = losses.mean() if dist.get_rank() == dist.get_world_size() - 1:
step_bar.set_postfix(
{
"train/loss": global_loss.item(),
"train/lr": self.actor_scheduler.get_last_lr()[0],
"train/chosen_rewards": chosen_rewards.to(torch.float16).mean().item(),
"train/rejected_rewards": rejected_rewards.to(torch.float16).mean().item(),
}
)
step_bar.update()
self.accumulative_meter.add("loss", global_loss.item())
self.accumulative_meter.add("chosen_rewards", chosen_rewards.to(torch.float16).mean().item())
self.accumulative_meter.add(
"rejected_rewards", rejected_rewards.to(torch.float16).mean().item()
)
if self.writer is not None:
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), i)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), i
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
i,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards")
- self.accumulative_meter.get("rejected_rewards"),
i,
)
self.booster.backward(loss=loss, optimizer=self.optimizer)
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
self.optimizer.step() self.optimizer.step()
self.optimizer.zero_grad() self.optimizer.zero_grad()
self.actor_scheduler.step() self.actor_scheduler.step()
else:
# sync self.accumulative_meter.reset()
loss_mean = all_reduce_mean(tensor=loss) step_bar = trange(
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards) len(self.train_dataloader) // self.accumulation_steps,
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards) desc=f"Epoch {epoch + 1}/{self.max_epochs}",
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies) disable=not is_rank_0(),
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()) )
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()) for i, batch in enumerate(self.train_dataloader):
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
if i % self.accumulation_steps == self.accumulation_steps - 1:
self.num_train_step += 1
step_bar.update()
# logging
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/accuracy",
self.accumulative_meter.get("accuracy"),
self.num_train_step,
)
self.accumulative_meter.reset()
if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
# save checkpoint
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.actor_scheduler,
epoch=epoch,
step=i + 1,
batch_size=batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
step_bar.close()
def _eval(self, epoch: int):
"""
Args:
epoch int: the number of current epoch
"""
if self.eval_dataloader is None:
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
return
self.model.eval()
self.ref_model.eval()
self.coordinator.print_on_master("\nStart evaluation...")
step_bar = trange(
len(self.eval_dataloader),
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
self.accumulative_meter.reset()
with torch.no_grad():
for i, batch in enumerate(self.eval_dataloader):
batch = to_device(batch, self.device) batch = to_device(batch, self.device)
( (
chosen_input_ids, chosen_input_ids,
@ -300,12 +290,11 @@ class DPOTrainer(SLTrainer):
batch_size = chosen_input_ids.size()[0] batch_size = chosen_input_ids.size()[0]
actor_all_logits = self.model( actor_all_logits = self.model(
torch.cat([chosen_input_ids, reject_input_ids]), input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
torch.cat([chosen_attention_mask, reject_attention_mask]), attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"] )["logits"]
actor_chosen_logits = actor_all_logits[:batch_size] actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:] actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs( logprob_actor_chosen = calc_masked_log_probs(
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
) )
@ -314,22 +303,26 @@ class DPOTrainer(SLTrainer):
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
) )
self.ref_model.eval() if self.ref_model is not None:
self.ref_model.eval()
ref_all_logits = self.ref_model( with torch.no_grad():
torch.cat([chosen_input_ids, reject_input_ids]), ref_all_logits = self.ref_model(
torch.cat([chosen_attention_mask, reject_attention_mask]), input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
)["logits"] attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
ref_chosen_logits = ref_all_logits[:batch_size] )["logits"]
ref_reject_logits = ref_all_logits[batch_size:] ref_chosen_logits = ref_all_logits[:batch_size]
logprob_ref_chosen = calc_masked_log_probs( ref_reject_logits = ref_all_logits[batch_size:]
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization logprob_ref_chosen = calc_masked_log_probs(
) ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
logprob_ref_reject = calc_masked_log_probs( )
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization logprob_ref_reject = calc_masked_log_probs(
) ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn( else:
logprob_ref_chosen = None
logprob_ref_reject = None
loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
logprob_actor_chosen, logprob_actor_chosen,
logprob_actor_reject, logprob_actor_reject,
logprob_ref_chosen if logprob_ref_chosen is not None else None, logprob_ref_chosen if logprob_ref_chosen is not None else None,
@ -338,7 +331,9 @@ class DPOTrainer(SLTrainer):
reject_loss_mask[:, 1:], reject_loss_mask[:, 1:],
) )
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean() reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
loss = losses.mean()
self.booster.backward(loss=loss, optimizer=self.optimizer)
# sync
loss_mean = all_reduce_mean(tensor=loss) loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards) chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards) rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
@ -347,16 +342,301 @@ class DPOTrainer(SLTrainer):
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()) self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item()) self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
self.accumulative_meter.add(
"margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item() if (i + 1) % self.accumulation_steps == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.actor_scheduler.step()
step_bar.set_postfix(
{
"train/loss": self.accumulative_meter.get("loss"),
"train/chosen_rewards": self.accumulative_meter.get("chosen_rewards"),
"train/rejected_rewards": self.accumulative_meter.get("rejected_rewards"),
"train/accuracy": self.accumulative_meter.get("accuracy"),
}
)
step_bar.update()
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards")
- self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/accuracy",
self.accumulative_meter.get("accuracy"),
self.num_train_step,
)
self.num_train_step += 1
self.accumulative_meter.reset()
if self.save_dir is not None and self.num_train_step > 0 and self.num_train_step % self.save_interval == 0:
# save checkpoint
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.actor_scheduler,
epoch=epoch,
step=self.num_train_step,
batch_size=batch_size,
coordinator=self.coordinator,
) )
step_bar.update() self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
msg = "Evaluation Result:\n" )
for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" step_bar.close()
self.coordinator.print_on_master(msg)
os.makedirs(self.save_dir, exist_ok=True) def _eval(self, epoch: int):
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: """
f.write(msg) Args:
epoch int: the number of current epoch
"""
if self.eval_dataloader is None:
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
return
self.model.eval()
self.ref_model.eval()
self.accumulative_meter.reset()
self.coordinator.print_on_master("\nStart evaluation...")
if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
step_bar = tqdm(
range(len(self.eval_dataloader)),
desc="Step",
disable=not (dist.get_rank() == dist.get_world_size() - 1),
)
with torch.no_grad():
for _, batch in enumerate(self.eval_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
chosen_loss_mask,
reject_input_ids,
reject_attention_mask,
reject_loss_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["chosen_loss_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
batch_size = chosen_input_ids.size()[0]
# Calculate logits from reference model.
if self.ref_model is not None:
self.ref_model.eval()
with torch.no_grad():
ref_all_logits = self.ref_model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
else:
logprob_ref_chosen = None
logprob_ref_reject = None
# Merge chosen and reject
inputs_ids = torch.stack([item for tup in zip(chosen_input_ids, reject_input_ids) for item in tup])
attention_mask = torch.stack(
[item for tup in zip(chosen_attention_mask, reject_attention_mask) for item in tup]
)
loss_mask = torch.stack([item for tup in zip(chosen_loss_mask, reject_loss_mask) for item in tup])
logprob_ref = torch.stack(
[item for tup in zip(logprob_ref_chosen, logprob_ref_reject) for item in tup]
)
data_iter = iter(
[
{
"input_ids": inputs_ids,
"attention_mask": attention_mask,
"loss_mask": loss_mask,
"logprob_ref": logprob_ref,
}
]
)
rewards = []
def _criterion(outputs, inputs):
loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
calc_masked_log_probs(
outputs["logits"][0::2],
inputs["input_ids"][0::2],
inputs["loss_mask"][0::2][:, 1:],
self.length_normalization,
),
calc_masked_log_probs(
outputs["logits"][1::2],
inputs["input_ids"][1::2],
inputs["loss_mask"][1::2][:, 1:],
self.length_normalization,
),
inputs["logprob_ref"][0::2] if inputs["logprob_ref"] is not None else None,
inputs["logprob_ref"][1::2] if inputs["logprob_ref"] is not None else None,
inputs["loss_mask"][0::2][:, 1:],
inputs["loss_mask"][1::2][:, 1:],
)
rewards.append(chosen_rewards)
rewards.append(rejected_rewards)
return loss
outputs = self.booster.execute_pipeline(
data_iter,
self.model,
criterion=_criterion,
optimizer=self.optimizer,
return_loss=True,
)
loss = outputs["loss"]
if self.booster.plugin.stage_manager.is_last_stage():
chosen_rewards, rejected_rewards = rewards[0], rewards[1]
global_loss = all_reduce_mean(loss, self.plugin)
chosen_rewards_mean = all_reduce_mean(chosen_rewards, self.plugin)
rejected_rewards_mean = all_reduce_mean(rejected_rewards, self.plugin)
if dist.get_rank() == dist.get_world_size() - 1:
step_bar.set_postfix(
{
"eval/loss": global_loss.item(),
"eval/lr": self.actor_scheduler.get_last_lr()[0],
"eval/chosen_rewards": chosen_rewards.to(torch.float16).mean().item(),
"eval/rejected_rewards": rejected_rewards.to(torch.float16).mean().item(),
}
)
self.accumulative_meter.add(
"chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()
)
self.accumulative_meter.add(
"rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()
)
self.accumulative_meter.add("loss", global_loss.to(torch.float16).item())
step_bar.update()
if self.booster.plugin.stage_manager.is_last_stage():
msg = "\nEvaluation Result:\n"
for tag in ["loss", "chosen_rewards", "rejected_rewards"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
if dist.get_rank() == dist.get_world_size() - 1:
print(msg)
else:
step_bar = trange(
len(self.eval_dataloader),
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
with torch.no_grad():
for i, batch in enumerate(self.eval_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
chosen_loss_mask,
reject_input_ids,
reject_attention_mask,
reject_loss_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["chosen_loss_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)
batch_size = chosen_input_ids.size()[0]
actor_all_logits = self.model(
torch.cat([chosen_input_ids, reject_input_ids]),
torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_actor_reject = calc_masked_log_probs(
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
ref_all_logits = self.ref_model(
torch.cat([chosen_input_ids, reject_input_ids]),
torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
logprob_actor_chosen,
logprob_actor_reject,
logprob_ref_chosen if logprob_ref_chosen is not None else None,
logprob_ref_reject if logprob_ref_reject is not None else None,
chosen_loss_mask[:, 1:],
reject_loss_mask[:, 1:],
)
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
loss = losses.mean()
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add(
"rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()
)
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
self.accumulative_meter.add(
"margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
)
step_bar.set_postfix(
{
"eval/loss": self.accumulative_meter.get("loss"),
"eval/chosen_rewards": self.accumulative_meter.get("chosen_rewards"),
"eval/rejected_rewards": self.accumulative_meter.get("rejected_rewards"),
"eval/accuracy": self.accumulative_meter.get("accuracy"),
}
)
step_bar.update()
msg = "\nEvaluation Result:\n"
for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
if self.save_dir is not None:
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close() step_bar.close()

@ -73,8 +73,7 @@ def main():
"--conversation_template_config", "--conversation_template_config",
type=str, type=str,
default="conversation_template_config", default="conversation_template_config",
help="Path \ help="Path to save conversation template config files.",
to save conversation template config files.",
) )
parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory") parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
parser.add_argument( parser.add_argument(

@ -13,7 +13,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai import colossalai
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@ -29,8 +29,6 @@ def train(args):
# check lora compatibility # check lora compatibility
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0: if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
# ============================== # ==============================
# Initialize Distributed Training # Initialize Distributed Training
@ -46,7 +44,7 @@ def train(args):
Default torch ddp plugin without any acceleration, for Default torch ddp plugin without any acceleration, for
debugging purpose acceleration, for debugging purpose debugging purpose acceleration, for debugging purpose
""" """
plugin = TorchDDPPlugin(find_unused_parameters=True) plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
elif args.plugin == "gemini": elif args.plugin == "gemini":
plugin = GeminiPlugin( plugin = GeminiPlugin(
precision=args.mixed_precision, precision=args.mixed_precision,
@ -56,14 +54,6 @@ def train(args):
enable_gradient_accumulation=True, enable_gradient_accumulation=True,
enable_flash_attention=args.use_flash_attn, enable_flash_attention=args.use_flash_attn,
) )
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2": elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin( plugin = LowLevelZeroPlugin(
stage=2, stage=2,
@ -92,20 +82,24 @@ def train(args):
parallel_output=False, parallel_output=False,
max_norm=args.grad_clip, max_norm=args.grad_clip,
precision=args.mixed_precision, precision=args.mixed_precision,
microbatch_size=args.microbatch_size,
) )
else: else:
raise ValueError(f"Unknown plugin {args.plugin}") raise ValueError(f"Unknown plugin {args.plugin}")
booster = Booster(plugin=plugin) booster = Booster(plugin=plugin)
ref_booster = Booster(plugin=plugin)
# ====================================================== ref_plugin = HybridParallelPlugin(
# Initialize Model, Objective, Optimizer and LR Scheduler tp_size=args.ref_tp,
# ====================================================== pp_size=1,
# Temp Fix: Disable lazy init due to version conflict zero_stage=args.zero_stage,
# init_ctx = ( enable_flash_attention=args.use_flash_attn,
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
# ) parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
)
ref_booster = Booster(plugin=ref_plugin)
init_ctx = nullcontext() init_ctx = nullcontext()
with init_ctx: with init_ctx:
@ -130,6 +124,7 @@ def train(args):
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain) ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
else: else:
ref_model = None ref_model = None
if args.lora_config is not None: if args.lora_config is not None:
model = convert_to_lora_module(model, lora_config=lora_config) model = convert_to_lora_module(model, lora_config=lora_config)
for name, module in model.named_modules(): for name, module in model.named_modules():
@ -139,7 +134,9 @@ def train(args):
disable_dropout(ref_model) disable_dropout(ref_model)
if args.grad_checkpoint: if args.grad_checkpoint:
# Note, for some models, lora may not be compatible with gradient checkpointing # Make sure gradient checkpointing can be activated.
model.train()
# Note, for some models, lora may not be compatible with gradient checkpointing.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
@ -169,7 +166,7 @@ def train(args):
adamw_mode=True, adamw_mode=True,
) )
# configure dataset # Configure dataset
coordinator.print_on_master(f"Load dataset: {args.dataset}") coordinator.print_on_master(f"Load dataset: {args.dataset}")
mode_map = {"train": "train", "valid": "validation", "test": "test"} mode_map = {"train": "train", "valid": "validation", "test": "test"}
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map) train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
@ -213,14 +210,15 @@ def train(args):
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype) torch.set_default_dtype(default_dtype)
model, optim, _, train_dataloader, lr_scheduler = booster.boost( model, optim, _, train_dataloader, lr_scheduler = booster.boost(
model=model, model=model,
optimizer=optim, optimizer=optim,
lr_scheduler=lr_scheduler, lr_scheduler=lr_scheduler,
dataloader=train_dataloader, dataloader=train_dataloader,
) )
if ref_model is not None: ref_model, _, _, _, _ = ref_booster.boost(model=ref_model)
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
torch.set_default_dtype(torch.float) torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
@ -312,7 +310,7 @@ if __name__ == "__main__":
"--plugin", "--plugin",
type=str, type=str,
default="gemini", default="gemini",
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"], choices=["gemini", "zero2", "zero2_cpu", "3d", "ddp"],
help="Choose which plugin to use", help="Choose which plugin to use",
) )
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
@ -342,22 +340,35 @@ if __name__ == "__main__":
parser.add_argument("--max_length", type=int, default=2048, help="Model max length") parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3) parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument(
"--disable_reference_model",
action="store_true",
default=False,
help="Disable the reference model (enabled by default)",
)
parser.add_argument("--disable_loss_mask", default=False, action="store_true") parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path") parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints") parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
parser.add_argument("--lr", type=float, default=5e-6) parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8) parser.add_argument("--accumulation_steps", type=int, default=1)
parser.add_argument("--log_dir", default=None, type=str) parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true") parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true") parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true") parser.add_argument("--use_flash_attn", default=False, action="store_true")
parser.add_argument(
"--microbatch_size",
type=int,
default=2,
help="Micro batch size for PP training. To activate PP training for DPO-like algorithm, you must keep size even and the size should be equal or greater than 2.",
)
# Parameter for reference model
parser.add_argument(
"--disable_reference_model",
action="store_true",
default=False,
help="Disable the reference model (enabled by default)",
)
parser.add_argument(
"--ref_tp",
type=int,
default=1,
help="TP size for reference model; used only when reference model is too large.",
)
args = parser.parse_args() args = parser.parse_args()
# fool proof hyperparameter setup # fool proof hyperparameter setup

@ -68,7 +68,7 @@ def train(args):
Default torch ddp plugin without any acceleration, for Default torch ddp plugin without any acceleration, for
debugging purpose acceleration, for debugging purpose debugging purpose acceleration, for debugging purpose
""" """
plugin = TorchDDPPlugin(find_unused_parameters=True if args.grad_checkpoint is False else False) plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
elif args.plugin == "gemini": elif args.plugin == "gemini":
plugin = GeminiPlugin( plugin = GeminiPlugin(
precision=args.mixed_precision, precision=args.mixed_precision,

@ -4,7 +4,7 @@ BASE_TEMP_DIR=$BASE_DIR/temp
EXAMPLES_DIR=$BASE_DIR/examples EXAMPLES_DIR=$BASE_DIR/examples
TEST_DATA_DIR=$BASE_DIR/tests/test_data TEST_DATA_DIR=$BASE_DIR/tests/test_data
DATA_SAVE_PATH=$BASE_TEMP_DIR/tests DATA_SAVE_PATH=$BASE_TEMP_DIR/tests
CONFIG_DIR=$BASE_DIR/config CONFIG_DIR=$BASE_DIR/conversation_template
# MODELS=("colossal-llama2" "llama2" "mistral" "chatGLM2" "chatGLM3" "deepseek" "Yi" "baichuan") # for local test # MODELS=("colossal-llama2" "llama2" "mistral" "chatGLM2" "chatGLM3" "deepseek" "Yi" "baichuan") # for local test
MODELS=("colossal-llama2" "llama2" "chatGLM2" "chatGLM3" "deepseek" "Yi") MODELS=("colossal-llama2" "llama2" "chatGLM2" "chatGLM3" "deepseek" "Yi")
@ -39,23 +39,23 @@ get_pretrain() {
get_conversation_template_config() { get_conversation_template_config() {
local model=$1 local model=$1
if [[ $model == "colossal-llama2" ]]; then if [[ $model == "colossal-llama2" ]]; then
echo "$CONFIG_DIR/conversation_template/colossal-llama2.json" echo "$CONFIG_DIR/colossal-llama2.json"
elif [[ $model == "llama2" ]]; then elif [[ $model == "llama2" ]]; then
echo "$CONFIG_DIR/conversation_template/llama2.json" echo "$CONFIG_DIR/llama2.json"
elif [[ $model == "deepseek" ]]; then elif [[ $model == "deepseek" ]]; then
echo "$CONFIG_DIR/conversation_template/deepseek-ai_DeepSeek-V2-Lite.json" echo "$CONFIG_DIR/deepseek-ai_DeepSeek-V2-Lite.json"
elif [[ $model == "mistral" ]]; then elif [[ $model == "mistral" ]]; then
echo "$CONFIG_DIR/conversation_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json" echo "$CONFIG_DIR/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
elif [[ $model == "chatGLM2" ]]; then elif [[ $model == "chatGLM2" ]]; then
echo "$CONFIG_DIR/conversation_template/THUDM_chatglm2-6b.json" echo "$CONFIG_DIR/THUDM_chatglm2-6b.json"
elif [[ $model == "chatGLM3" ]]; then elif [[ $model == "chatGLM3" ]]; then
echo "$CONFIG_DIR/conversation_template/THUDM_chatglm3-6b.json" echo "$CONFIG_DIR/THUDM_chatglm3-6b.json"
elif [[ $model == "phi" ]]; then elif [[ $model == "phi" ]]; then
echo "$CONFIG_DIR/conversation_template/microsoft_phi-2.json" echo "$CONFIG_DIR/microsoft_phi-2.json"
elif [[ $model == "Yi" ]]; then elif [[ $model == "Yi" ]]; then
echo "$CONFIG_DIR/conversation_template/01-ai_Yi-1.5-9B-Chat.json" echo "$CONFIG_DIR/01-ai_Yi-1.5-9B-Chat.json"
elif [[ $model == "baichuan" ]]; then elif [[ $model == "baichuan" ]]; then
echo "$CONFIG_DIR/conversation_template/baichuan-inc_Baichuan2-13B-Chat.json" echo "$CONFIG_DIR/baichuan-inc_Baichuan2-13B-Chat.json"
else else
echo "Unknown model $model" echo "Unknown model $model"
exit 1 exit 1
@ -71,6 +71,7 @@ for model in ${MODELS[@]}; do
rm -rf $SAVE_DIR/arrow rm -rf $SAVE_DIR/arrow
pretrain=$(get_pretrain $model) pretrain=$(get_pretrain $model)
conversation_template_config=$(get_conversation_template_config $model) conversation_template_config=$(get_conversation_template_config $model)
echo $conversation_template_config
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type sft --data_input_dirs $TEST_DATA_DIR/sft \ python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type sft --data_input_dirs $TEST_DATA_DIR/sft \
--tokenizer_dir $pretrain \ --tokenizer_dir $pretrain \
--conversation_template_config $conversation_template_config \ --conversation_template_config $conversation_template_config \

@ -271,6 +271,7 @@ class LlamaPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None, shard_config: ShardConfig = None,
**kwargs,
): ):
r""" r"""
Args: Args:

Loading…
Cancel
Save