mirror of https://github.com/hpcaitech/ColossalAI
Support overall loss, update KTO logging
parent
75c963686f
commit
0b2d55c4ab
|
@ -49,6 +49,10 @@ def tokenize_sft(
|
|||
|
||||
messages = data_point["messages"]
|
||||
template = deepcopy(conversation_template)
|
||||
|
||||
if messages[0]["from"] == "system":
|
||||
template.system_message = str(messages[0]["content"])
|
||||
messages.pop(0)
|
||||
template.messages = []
|
||||
for idx, mess in enumerate(messages):
|
||||
if mess["from"] != template.roles[idx % 2]:
|
||||
|
@ -148,11 +152,14 @@ def tokenize_prompt(
|
|||
template = deepcopy(conversation_template)
|
||||
template.messages = []
|
||||
|
||||
if messages[0]["from"] == "system":
|
||||
template.system_message = str(messages[0]["content"])
|
||||
messages.pop(0)
|
||||
|
||||
for idx, mess in enumerate(messages):
|
||||
if mess["from"] != template.roles[idx % 2]:
|
||||
raise ValueError(
|
||||
f"Message should iterate between user and assistant and starts with a \
|
||||
line from the user. Got the following data:\n{messages}"
|
||||
f"Message should iterate between user and assistant and starts with a line from the user. Got the following data:\n{messages}"
|
||||
)
|
||||
template.append_message(mess["from"], mess["content"])
|
||||
|
||||
|
@ -225,6 +232,10 @@ def tokenize_rlhf(
|
|||
template = deepcopy(conversation_template)
|
||||
template.clear()
|
||||
|
||||
if context[0]["from"] == "system":
|
||||
template.system_message = str(context[0]["content"])
|
||||
context.pop(0)
|
||||
|
||||
for idx, mess in enumerate(context):
|
||||
if mess["from"] != template.roles[idx % 2]:
|
||||
raise ValueError(
|
||||
|
@ -345,6 +356,10 @@ def tokenize_kto(
|
|||
template = deepcopy(conversation_template)
|
||||
template.clear()
|
||||
|
||||
if prompt[0]["from"] == "system":
|
||||
template.system_message = str(prompt[0]["content"])
|
||||
prompt.pop(0)
|
||||
|
||||
if prompt[0].get("from", None) != "user":
|
||||
raise ValueError("conversation should start with user")
|
||||
if completion.get("from", None) != "assistant":
|
||||
|
|
|
@ -46,7 +46,10 @@ class PolicyLoss(nn.Module):
|
|||
action_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
skip = False
|
||||
ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
|
||||
if action_mask is None:
|
||||
ratio_ = (log_probs - old_log_probs).exp()
|
||||
else:
|
||||
ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
|
||||
|
||||
# note that if dropout is disabled (recommanded), ratio will always be 1.
|
||||
if ratio_.mean() > self.skip_threshold:
|
||||
|
@ -56,7 +59,10 @@ class PolicyLoss(nn.Module):
|
|||
surr1 = ratio * advantages
|
||||
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
||||
loss = -torch.min(surr1, surr2)
|
||||
loss = masked_mean(loss, action_mask)
|
||||
if action_mask is not None:
|
||||
loss = masked_mean(loss, action_mask)
|
||||
else:
|
||||
loss = loss.mean(dim=1)
|
||||
loss = loss.mean()
|
||||
return loss, skip, ratio_.max()
|
||||
|
||||
|
@ -81,8 +87,10 @@ class ValueLoss(nn.Module):
|
|||
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
|
||||
surr1 = (values_clipped - returns) ** 2
|
||||
surr2 = (values - returns) ** 2
|
||||
loss = torch.max(surr1, surr2) / torch.sum(action_mask)
|
||||
loss = torch.sum(loss * action_mask)
|
||||
if action_mask is not None:
|
||||
loss = torch.sum(torch.max(surr1, surr2) / torch.sum(action_mask) * action_mask)
|
||||
else:
|
||||
loss = torch.mean(torch.max(surr1, surr2))
|
||||
return 0.5 * loss
|
||||
|
||||
|
||||
|
|
|
@ -56,6 +56,7 @@ class DPOTrainer(SLTrainer):
|
|||
beta: float = 0.1,
|
||||
gamma: float = 0.0,
|
||||
length_normalization: bool = False,
|
||||
apply_loss_mask: bool = True,
|
||||
accumulation_steps: int = 1,
|
||||
start_epoch: int = 0,
|
||||
save_interval: int = 0,
|
||||
|
@ -67,6 +68,7 @@ class DPOTrainer(SLTrainer):
|
|||
self.actor_scheduler = actor_lr_scheduler
|
||||
self.tokenizer = tokenizer
|
||||
self.actor_loss_fn = DpoLoss(beta, gamma)
|
||||
self.apply_loss_mask = apply_loss_mask
|
||||
self.save_interval = save_interval
|
||||
self.coordinator = coordinator
|
||||
self.save_dir = save_dir
|
||||
|
@ -135,6 +137,10 @@ class DPOTrainer(SLTrainer):
|
|||
batch["reject_attention_mask"],
|
||||
batch["reject_loss_mask"],
|
||||
)
|
||||
if not self.apply_loss_mask:
|
||||
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
|
||||
reject_loss_mask = reject_loss_mask.fill_(1.0)
|
||||
|
||||
batch_size = chosen_input_ids.size()[0]
|
||||
|
||||
actor_all_logits = self.model(
|
||||
|
@ -284,6 +290,9 @@ class DPOTrainer(SLTrainer):
|
|||
batch["reject_attention_mask"],
|
||||
batch["reject_loss_mask"],
|
||||
)
|
||||
if not self.apply_loss_mask:
|
||||
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
|
||||
reject_loss_mask = reject_loss_mask.fill_(1.0)
|
||||
|
||||
batch_size = chosen_input_ids.size()[0]
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import os
|
|||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.distributed as dist
|
||||
from coati.models.loss import KTOLoss
|
||||
from coati.models.utils import calc_masked_log_probs
|
||||
from coati.trainer.utils import all_reduce_mean
|
||||
|
@ -59,6 +59,7 @@ class KTOTrainer(SLTrainer):
|
|||
beta: float = 0.1,
|
||||
desirable_weight: float = 1.0,
|
||||
undesirable_weight: float = 1.0,
|
||||
apply_loss_mask: bool = True,
|
||||
accumulation_steps: int = 1,
|
||||
start_epoch: int = 0,
|
||||
save_interval: int = 0,
|
||||
|
@ -70,6 +71,7 @@ class KTOTrainer(SLTrainer):
|
|||
self.actor_scheduler = actor_lr_scheduler
|
||||
self.tokenizer = tokenizer
|
||||
self.kto_loss = KTOLoss(beta=beta, desirable_weight=desirable_weight, undesirable_weight=undesirable_weight)
|
||||
self.apply_loss_mask = apply_loss_mask
|
||||
self.save_interval = save_interval
|
||||
self.coordinator = coordinator
|
||||
self.save_dir = save_dir
|
||||
|
@ -134,6 +136,10 @@ class KTOTrainer(SLTrainer):
|
|||
batch["kl_attention_mask"],
|
||||
batch["kl_loss_mask"],
|
||||
)
|
||||
if not self.apply_loss_mask:
|
||||
loss_mask = loss_mask.fill_(1.0)
|
||||
kl_loss_mask = kl_loss_mask.fill_(1.0)
|
||||
|
||||
batch_size = input_ids.size()[0]
|
||||
|
||||
# actor logits
|
||||
|
@ -182,8 +188,28 @@ class KTOTrainer(SLTrainer):
|
|||
|
||||
# sync
|
||||
loss_mean = all_reduce_mean(tensor=loss)
|
||||
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean())
|
||||
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean())
|
||||
chosen_reward_mean = chosen_rewards.mean()
|
||||
chosen_rewards_list = [
|
||||
torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
|
||||
]
|
||||
dist.all_gather(chosen_rewards_list, chosen_reward_mean)
|
||||
rejected_reward_mean = rejected_rewards.mean()
|
||||
rejected_rewards_list = [
|
||||
torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
|
||||
]
|
||||
dist.all_gather(rejected_rewards_list, rejected_reward_mean)
|
||||
chosen_rewards_list = [i for i in chosen_rewards_list if not i.isnan()]
|
||||
rejected_rewards_list = [i for i in rejected_rewards_list if not i.isnan()]
|
||||
chosen_rewards_mean = (
|
||||
torch.stack(chosen_rewards_list).mean()
|
||||
if len(chosen_rewards_list) > 0
|
||||
else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device)
|
||||
)
|
||||
rejected_rewards_mean = (
|
||||
torch.stack(rejected_rewards_list).mean()
|
||||
if len(rejected_rewards_list) > 0
|
||||
else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device)
|
||||
)
|
||||
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
|
||||
|
@ -256,6 +282,11 @@ class KTOTrainer(SLTrainer):
|
|||
batch["kl_attention_mask"],
|
||||
batch["kl_loss_mask"],
|
||||
)
|
||||
|
||||
if not self.apply_loss_mask:
|
||||
loss_mask = loss_mask.fill_(1.0)
|
||||
kl_loss_mask = kl_loss_mask.fill_(1.0)
|
||||
|
||||
batch_size = input_ids.size()[0]
|
||||
|
||||
# actor logits
|
||||
|
|
|
@ -52,6 +52,7 @@ class ORPOTrainer(SLTrainer):
|
|||
tokenizer: PreTrainedTokenizerBase,
|
||||
max_epochs: int = 1,
|
||||
lam: float = 0.1,
|
||||
apply_loss_mask: bool = True,
|
||||
accumulation_steps: int = 1,
|
||||
start_epoch: int = 0,
|
||||
save_interval: int = 0,
|
||||
|
@ -67,6 +68,7 @@ class ORPOTrainer(SLTrainer):
|
|||
self.save_dir = save_dir
|
||||
self.num_train_step = 0
|
||||
self.lam = lam
|
||||
self.apply_loss_mask = apply_loss_mask
|
||||
self.accumulation_steps = accumulation_steps
|
||||
self.device = get_current_device()
|
||||
self.accumulative_meter = AccumulativeMeanMeter()
|
||||
|
@ -130,6 +132,11 @@ class ORPOTrainer(SLTrainer):
|
|||
batch["reject_attention_mask"],
|
||||
batch["reject_loss_mask"],
|
||||
)
|
||||
|
||||
if not self.apply_loss_mask:
|
||||
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
|
||||
reject_loss_mask = reject_loss_mask.fill_(1.0)
|
||||
|
||||
batch_size = chosen_input_ids.size()[0]
|
||||
actor_out = self.model(
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
|
@ -263,6 +270,11 @@ class ORPOTrainer(SLTrainer):
|
|||
batch["reject_attention_mask"],
|
||||
batch["reject_loss_mask"],
|
||||
)
|
||||
|
||||
if not self.apply_loss_mask:
|
||||
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
|
||||
reject_loss_mask = reject_loss_mask.fill_(1.0)
|
||||
|
||||
batch_size = chosen_input_ids.size()[0]
|
||||
actor_out = self.model(
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
|
|
|
@ -102,6 +102,7 @@ class PPOTrainer(OLTrainer):
|
|||
sample_buffer: bool = False,
|
||||
dataloader_pin_memory: bool = True,
|
||||
offload_inference_models: bool = True,
|
||||
apply_loss_mask: bool = True,
|
||||
accumulation_steps: int = 1,
|
||||
save_interval: int = 0,
|
||||
save_dir: str = None,
|
||||
|
@ -140,6 +141,7 @@ class PPOTrainer(OLTrainer):
|
|||
self.actor_optim = actor_optim
|
||||
self.critic_optim = critic_optim
|
||||
self.save_interval = save_interval
|
||||
self.apply_loss_mask = apply_loss_mask
|
||||
self.coordinator = coordinator
|
||||
self.actor_save_dir = os.path.join(save_dir, "actor")
|
||||
self.critic_save_dir = os.path.join(save_dir, "critic")
|
||||
|
@ -229,7 +231,10 @@ class PPOTrainer(OLTrainer):
|
|||
action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
|
||||
|
||||
actor_loss, to_skip, max_ratio = self.actor_loss_fn(
|
||||
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
|
||||
action_log_probs,
|
||||
experience.action_log_probs,
|
||||
experience.advantages,
|
||||
action_mask=experience.action_mask if self.apply_loss_mask else None,
|
||||
)
|
||||
actor_loss = (1 - self.ptx_coef) * actor_loss
|
||||
if not to_skip:
|
||||
|
@ -249,7 +254,10 @@ class PPOTrainer(OLTrainer):
|
|||
input_ids=experience.sequences, attention_mask=experience.attention_mask
|
||||
) # [batch size, prompt_length + response_length]
|
||||
critic_loss = self.critic_loss_fn(
|
||||
values[:, -num_actions:], experience.values, experience.advantages, action_mask=experience.action_mask
|
||||
values[:, -num_actions:],
|
||||
experience.values,
|
||||
experience.advantages,
|
||||
action_mask=experience.action_mask if self.apply_loss_mask else None,
|
||||
)
|
||||
critic_loss = critic_loss * self.vf_coef
|
||||
self.critic_booster.backward(loss=critic_loss, optimizer=self.critic_optim)
|
||||
|
|
|
@ -41,6 +41,7 @@ class SFTTrainer(SLTrainer):
|
|||
lr_scheduler: _LRScheduler,
|
||||
max_epochs: int = 2,
|
||||
accumulation_steps: int = 8,
|
||||
apply_loss_mask: bool = True,
|
||||
start_epoch=0,
|
||||
save_interval: int = None,
|
||||
save_dir: str = None,
|
||||
|
@ -55,6 +56,7 @@ class SFTTrainer(SLTrainer):
|
|||
self.coordinator = coordinator
|
||||
self.num_train_step = 0
|
||||
self.num_eval_step = 0
|
||||
self.apply_loss_mask = apply_loss_mask
|
||||
self.accumulative_meter = AccumulativeMeanMeter()
|
||||
|
||||
def _before_fit(
|
||||
|
@ -100,7 +102,11 @@ class SFTTrainer(SLTrainer):
|
|||
for i, batch in enumerate(self.train_dataloader):
|
||||
batch = to_device(batch, torch.cuda.current_device())
|
||||
batch_size = batch["input_ids"].size(0)
|
||||
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
|
||||
outputs = self.model(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
|
||||
)
|
||||
loss = outputs.loss
|
||||
|
||||
self.booster.backward(loss=loss, optimizer=self.optimizer)
|
||||
|
@ -158,7 +164,11 @@ class SFTTrainer(SLTrainer):
|
|||
)
|
||||
for batch in self.eval_dataloader:
|
||||
batch = to_device(batch, torch.cuda.current_device())
|
||||
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
|
||||
outputs = self.model(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
|
||||
)
|
||||
loss_mean = all_reduce_mean(tensor=outputs.loss)
|
||||
self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
|
||||
step_bar.update()
|
||||
|
|
|
@ -387,6 +387,7 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
|
|||
- save_dir: path to store the model checkpoints.
|
||||
- max_length: input will be padded/truncated to max_length before feeding to the model.
|
||||
- max_epochs: number of epochs to train.
|
||||
- disable_loss_mask: whether to use the loss mask to mask the loss or not. For example, in SFT, if the loss mask is disabled, the model will compute the loss across all tokens in the sequence, if the loss mask is applied, only tokens correspond to the assistant responses will contribute to the final loss.
|
||||
- batch_size: training batch size.
|
||||
- mixed_precision: precision to use in training. Support 'fp16' and 'bf16'. Note that some devices may not support the 'bf16' option, please refer to [Nvidia](https://developer.nvidia.com/) to check compatibility.
|
||||
- save_interval: save the model weights as well as optimizer/scheduler states every save_interval steps/episodes.
|
||||
|
|
|
@ -53,8 +53,8 @@ def load_model_and_tokenizer(model_path, tokenizer_path, device="cuda", **kwargs
|
|||
tuple: A tuple containing the loaded model and tokenizer.
|
||||
"""
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs, trust_remote_code=True).to(torch.bfloat16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model.to(device)
|
||||
|
||||
|
|
|
@ -1,104 +0,0 @@
|
|||
|
||||
|
||||
==========
|
||||
round 1:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
tell me a story [/INST] Great, let’s hear a story. </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 2:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
tell me a story [/INST] Great, let’s hear a story. </s><s>[INST] calculate 1+1 [/INST] 1+1=2 </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 3:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
tell me a story [/INST] Great, let’s hear a story. </s><s>[INST] calculate 1+1 [/INST] 1+1=2 </s><s>[INST] who is the first president of the USA [/INST] The first president of the United States was George Washington. </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 1:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
who is the first president of the USA? [/INST] The first president of the United States was George Washington. </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 2:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
who is the first president of the USA? [/INST] The first president of the United States was George Washington. </s><s>[INST] tell me a story [/INST] One story that might be interesting is the story of how the United States was founded. In 1776, the Thirteen Colonies united together to form the new nation of America. The first president of the new nation was George Washington. The first state was Pennsylvania. The first capital was Philadelphia. The first presidential election was held in 1787, and was between George Read, a Federalist, and John Adams, a Republican. The Federalists wanted to keep the power of the federal government limited, while the Republicans wanted the power to be spread around to the states. The Federalists won, and George Washington was elected president. </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 1:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 2:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? </s><s>[INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 3:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? </s><s>[INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? </s><s>[INST] Is Donald Trump the president of the United States [/INST] Yes, Donald Trump became the 45th president of the United States in January of 2016. </s>
|
||||
|
||||
==========
|
|
@ -278,6 +278,7 @@ def train(args):
|
|||
beta=args.beta,
|
||||
gamma=args.gamma,
|
||||
length_normalization=args.length_normalization,
|
||||
apply_loss_mask=not args.disable_loss_mask,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
|
@ -346,6 +347,7 @@ if __name__ == "__main__":
|
|||
default=False,
|
||||
help="Disable the reference model (enabled by default)",
|
||||
)
|
||||
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||
|
|
|
@ -297,6 +297,7 @@ def train(args):
|
|||
beta=args.beta,
|
||||
desirable_weight=args.desirable_weight,
|
||||
undesirable_weight=args.undesirable_weight,
|
||||
apply_loss_mask=not args.disable_loss_mask,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
|
@ -341,6 +342,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss")
|
||||
parser.add_argument("--desirable_weight", type=float, default=1.0, help="desirable_weight in KTO loss")
|
||||
parser.add_argument("--undesirable_weight", type=float, default=1.0, help="undesirable_weight in KTO loss")
|
||||
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
|
|
|
@ -259,6 +259,7 @@ def train(args):
|
|||
save_dir=args.save_dir,
|
||||
coordinator=coordinator,
|
||||
lam=args.lam,
|
||||
apply_loss_mask=not args.disable_loss_mask,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
|
@ -301,6 +302,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss")
|
||||
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
|
|
|
@ -411,6 +411,7 @@ def train(args):
|
|||
use_cache=True,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
apply_loss_mask=not args.disable_loss_mask,
|
||||
accumulation_steps=args.accumulation_steps,
|
||||
save_dir=args.save_path,
|
||||
save_interval=args.save_interval,
|
||||
|
@ -498,6 +499,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--critic_lr", type=float, default=9e-6)
|
||||
parser.add_argument("--kl_coef", type=float, default=0.1)
|
||||
parser.add_argument("--ptx_coef", type=float, default=0.0)
|
||||
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
|
||||
parser.add_argument("--max_length", type=int, default=2048)
|
||||
parser.add_argument("--max_seq_len", type=int, default=256)
|
||||
parser.add_argument("--log_dir", default="logs", type=str)
|
||||
|
|
|
@ -272,6 +272,7 @@ def train(args):
|
|||
lr_scheduler=lr_scheduler,
|
||||
max_epochs=args.max_epochs,
|
||||
accumulation_steps=args.accumulation_steps,
|
||||
apply_loss_mask=not args.disable_loss_mask,
|
||||
start_epoch=start_epoch,
|
||||
save_interval=args.save_interval,
|
||||
save_dir=args.save_path,
|
||||
|
@ -317,6 +318,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
|
|
Loading…
Reference in New Issue