
Support overall loss, update KTO logging

colossalchat
YeAnbang · 4 months ago
commit 0b2d55c4ab
Changed files:
  1. applications/ColossalChat/coati/dataset/tokenization_utils.py (19 lines changed)
  2. applications/ColossalChat/coati/models/loss.py (16 lines changed)
  3. applications/ColossalChat/coati/trainer/dpo.py (9 lines changed)
  4. applications/ColossalChat/coati/trainer/kto.py (37 lines changed)
  5. applications/ColossalChat/coati/trainer/orpo.py (12 lines changed)
  6. applications/ColossalChat/coati/trainer/ppo.py (12 lines changed)
  7. applications/ColossalChat/coati/trainer/sft.py (14 lines changed)
  8. applications/ColossalChat/examples/README.md (1 line changed)
  9. applications/ColossalChat/examples/inference/inference.py (4 lines changed)
  10. applications/ColossalChat/examples/inference/round.txt (104 lines changed)
  11. applications/ColossalChat/examples/training_scripts/train_dpo.py (2 lines changed)
  12. applications/ColossalChat/examples/training_scripts/train_kto.py (2 lines changed)
  13. applications/ColossalChat/examples/training_scripts/train_orpo.py (2 lines changed)
  14. applications/ColossalChat/examples/training_scripts/train_ppo.py (2 lines changed)
  15. applications/ColossalChat/examples/training_scripts/train_sft.py (2 lines changed)

applications/ColossalChat/coati/dataset/tokenization_utils.py (19 lines changed)

@@ -49,6 +49,10 @@ def tokenize_sft(
     messages = data_point["messages"]
     template = deepcopy(conversation_template)
+    if messages[0]["from"] == "system":
+        template.system_message = str(messages[0]["content"])
+        messages.pop(0)
     template.messages = []
     for idx, mess in enumerate(messages):
         if mess["from"] != template.roles[idx % 2]:
@@ -148,11 +152,14 @@ def tokenize_prompt(
     template = deepcopy(conversation_template)
     template.messages = []
+    if messages[0]["from"] == "system":
+        template.system_message = str(messages[0]["content"])
+        messages.pop(0)
     for idx, mess in enumerate(messages):
         if mess["from"] != template.roles[idx % 2]:
             raise ValueError(
-                f"Message should iterate between user and assistant and starts with a \
-                    line from the user. Got the following data:\n{messages}"
+                f"Message should iterate between user and assistant and starts with a line from the user. Got the following data:\n{messages}"
             )
         template.append_message(mess["from"], mess["content"])
@@ -225,6 +232,10 @@ def tokenize_rlhf(
     template = deepcopy(conversation_template)
     template.clear()
+    if context[0]["from"] == "system":
+        template.system_message = str(context[0]["content"])
+        context.pop(0)
     for idx, mess in enumerate(context):
         if mess["from"] != template.roles[idx % 2]:
             raise ValueError(
@@ -345,6 +356,10 @@ def tokenize_kto(
     template = deepcopy(conversation_template)
     template.clear()
+    if prompt[0]["from"] == "system":
+        template.system_message = str(prompt[0]["content"])
+        prompt.pop(0)
     if prompt[0].get("from", None) != "user":
         raise ValueError("conversation should start with user")
     if completion.get("from", None) != "assistant":
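
These four hunks add the same behaviour: when the first turn of a sample is a system message, it is lifted into template.system_message and removed from the turn list, so the remaining turns still alternate user/assistant as the tokenizers require. A toy, standalone sketch of the data shape this handles (the Template class below is a stand-in, not the coati conversation template):

from copy import deepcopy
from dataclasses import dataclass, field

@dataclass
class Template:  # stand-in for the coati conversation template
    system_message: str = ""
    messages: list = field(default_factory=list)

messages = [
    {"from": "system", "content": "You are a helpful assistant."},
    {"from": "user", "content": "tell me a story"},
    {"from": "assistant", "content": "Once upon a time..."},
]

template = deepcopy(Template())
if messages[0]["from"] == "system":
    template.system_message = str(messages[0]["content"])
    messages.pop(0)

# the remaining turns now alternate user/assistant, as tokenize_sft and friends expect
assert [m["from"] for m in messages] == ["user", "assistant"]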

applications/ColossalChat/coati/models/loss.py (16 lines changed)

@@ -46,7 +46,10 @@ class PolicyLoss(nn.Module):
         action_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         skip = False
-        ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
+        if action_mask is None:
+            ratio_ = (log_probs - old_log_probs).exp()
+        else:
+            ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
         # note that if dropout is disabled (recommanded), ratio will always be 1.
         if ratio_.mean() > self.skip_threshold:
@@ -56,7 +59,10 @@ class PolicyLoss(nn.Module):
         surr1 = ratio * advantages
         surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
         loss = -torch.min(surr1, surr2)
-        loss = masked_mean(loss, action_mask)
+        if action_mask is not None:
+            loss = masked_mean(loss, action_mask)
+        else:
+            loss = loss.mean(dim=1)
         loss = loss.mean()
         return loss, skip, ratio_.max()
@@ -81,8 +87,10 @@ class ValueLoss(nn.Module):
         values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
         surr1 = (values_clipped - returns) ** 2
         surr2 = (values - returns) ** 2
-        loss = torch.max(surr1, surr2) / torch.sum(action_mask)
-        loss = torch.sum(loss * action_mask)
+        if action_mask is not None:
+            loss = torch.sum(torch.max(surr1, surr2) / torch.sum(action_mask) * action_mask)
+        else:
+            loss = torch.mean(torch.max(surr1, surr2))
         return 0.5 * loss
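
With these changes the RL losses accept action_mask=None: when a mask is given, only action (response) tokens enter the mean; when it is None, every position is averaged, which is what the "overall loss" option uses. A standalone sketch of the two reductions for the clipped policy loss (masked_mean is redefined locally and the tensor shapes are illustrative, so this is not the coati implementation):

import torch

def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # mean over `dim`, counting only positions where mask == 1
    return (tensor * mask).sum(dim=dim) / mask.sum(dim=dim).clamp(min=1)

log_probs = torch.randn(2, 5)      # new policy log-probs, [batch, num_actions]
old_log_probs = torch.randn(2, 5)  # behaviour policy log-probs
advantages = torch.randn(2, 5)
action_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 0.0]])
clip_eps = 0.2

ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantages
per_token_loss = -torch.min(surr1, surr2)

masked_loss = masked_mean(per_token_loss, action_mask).mean()  # action_mask provided
overall_loss = per_token_loss.mean(dim=1).mean()               # action_mask=None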

applications/ColossalChat/coati/trainer/dpo.py (9 lines changed)

@@ -56,6 +56,7 @@ class DPOTrainer(SLTrainer):
         beta: float = 0.1,
         gamma: float = 0.0,
         length_normalization: bool = False,
+        apply_loss_mask: bool = True,
         accumulation_steps: int = 1,
         start_epoch: int = 0,
         save_interval: int = 0,
@@ -67,6 +68,7 @@ class DPOTrainer(SLTrainer):
         self.actor_scheduler = actor_lr_scheduler
         self.tokenizer = tokenizer
         self.actor_loss_fn = DpoLoss(beta, gamma)
+        self.apply_loss_mask = apply_loss_mask
         self.save_interval = save_interval
         self.coordinator = coordinator
         self.save_dir = save_dir
@@ -135,6 +137,10 @@ class DPOTrainer(SLTrainer):
                 batch["reject_attention_mask"],
                 batch["reject_loss_mask"],
             )
+            if not self.apply_loss_mask:
+                chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+                reject_loss_mask = reject_loss_mask.fill_(1.0)
             batch_size = chosen_input_ids.size()[0]
             actor_all_logits = self.model(
@@ -284,6 +290,9 @@ class DPOTrainer(SLTrainer):
                 batch["reject_attention_mask"],
                 batch["reject_loss_mask"],
             )
+            if not self.apply_loss_mask:
+                chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+                reject_loss_mask = reject_loss_mask.fill_(1.0)
             batch_size = chosen_input_ids.size()[0]

applications/ColossalChat/coati/trainer/kto.py (37 lines changed)

@@ -6,7 +6,7 @@ import os
 from typing import Any, Optional
 import torch
-import torch.distributed
+import torch.distributed as dist
 from coati.models.loss import KTOLoss
 from coati.models.utils import calc_masked_log_probs
 from coati.trainer.utils import all_reduce_mean
@@ -59,6 +59,7 @@ class KTOTrainer(SLTrainer):
         beta: float = 0.1,
         desirable_weight: float = 1.0,
         undesirable_weight: float = 1.0,
+        apply_loss_mask: bool = True,
         accumulation_steps: int = 1,
         start_epoch: int = 0,
         save_interval: int = 0,
@@ -70,6 +71,7 @@ class KTOTrainer(SLTrainer):
         self.actor_scheduler = actor_lr_scheduler
         self.tokenizer = tokenizer
         self.kto_loss = KTOLoss(beta=beta, desirable_weight=desirable_weight, undesirable_weight=undesirable_weight)
+        self.apply_loss_mask = apply_loss_mask
         self.save_interval = save_interval
         self.coordinator = coordinator
         self.save_dir = save_dir
@@ -134,6 +136,10 @@ class KTOTrainer(SLTrainer):
                 batch["kl_attention_mask"],
                 batch["kl_loss_mask"],
             )
+            if not self.apply_loss_mask:
+                loss_mask = loss_mask.fill_(1.0)
+                kl_loss_mask = kl_loss_mask.fill_(1.0)
             batch_size = input_ids.size()[0]
             # actor logits
@@ -182,8 +188,28 @@ class KTOTrainer(SLTrainer):
             # sync
             loss_mean = all_reduce_mean(tensor=loss)
-            chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean())
-            rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean())
+            chosen_reward_mean = chosen_rewards.mean()
+            chosen_rewards_list = [
+                torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
+            ]
+            dist.all_gather(chosen_rewards_list, chosen_reward_mean)
+            rejected_reward_mean = rejected_rewards.mean()
+            rejected_rewards_list = [
+                torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
+            ]
+            dist.all_gather(rejected_rewards_list, rejected_reward_mean)
+            chosen_rewards_list = [i for i in chosen_rewards_list if not i.isnan()]
+            rejected_rewards_list = [i for i in rejected_rewards_list if not i.isnan()]
+            chosen_rewards_mean = (
+                torch.stack(chosen_rewards_list).mean()
+                if len(chosen_rewards_list) > 0
+                else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device)
+            )
+            rejected_rewards_mean = (
+                torch.stack(rejected_rewards_list).mean()
+                if len(rejected_rewards_list) > 0
+                else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device)
+            )
             self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
             self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
             self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
@@ -256,6 +282,11 @@ class KTOTrainer(SLTrainer):
                 batch["kl_attention_mask"],
                 batch["kl_loss_mask"],
             )
+            if not self.apply_loss_mask:
+                loss_mask = loss_mask.fill_(1.0)
+                kl_loss_mask = kl_loss_mask.fill_(1.0)
             batch_size = input_ids.size()[0]
             # actor logits
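
This is the logging update from the commit title: instead of all-reducing the per-rank reward means (where a single NaN, e.g. from a micro-batch that happens to contain no desirable or no undesirable samples, poisons the logged average), the trainer now all-gathers the per-rank means and averages only the non-NaN entries. A standalone sketch of that pattern, assuming an initialized torch.distributed process group (the helper name is illustrative, not part of the trainer):

import torch
import torch.distributed as dist

def nan_robust_mean(local_value: torch.Tensor) -> torch.Tensor:
    """All-gather a scalar from every rank and average only the non-NaN entries."""
    gathered = [torch.zeros_like(local_value) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local_value)
    valid = [v for v in gathered if not torch.isnan(v)]
    if len(valid) == 0:
        return torch.tensor(float("nan"), dtype=local_value.dtype, device=local_value.device)
    return torch.stack(valid).mean()

# illustrative use inside a training step:
# chosen_rewards_mean = nan_robust_mean(chosen_rewards.mean())
# rejected_rewards_mean = nan_robust_mean(rejected_rewards.mean())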

applications/ColossalChat/coati/trainer/orpo.py (12 lines changed)

@@ -52,6 +52,7 @@ class ORPOTrainer(SLTrainer):
         tokenizer: PreTrainedTokenizerBase,
         max_epochs: int = 1,
         lam: float = 0.1,
+        apply_loss_mask: bool = True,
         accumulation_steps: int = 1,
         start_epoch: int = 0,
         save_interval: int = 0,
@@ -67,6 +68,7 @@ class ORPOTrainer(SLTrainer):
         self.save_dir = save_dir
         self.num_train_step = 0
         self.lam = lam
+        self.apply_loss_mask = apply_loss_mask
         self.accumulation_steps = accumulation_steps
         self.device = get_current_device()
         self.accumulative_meter = AccumulativeMeanMeter()
@@ -130,6 +132,11 @@ class ORPOTrainer(SLTrainer):
                 batch["reject_attention_mask"],
                 batch["reject_loss_mask"],
             )
+            if not self.apply_loss_mask:
+                chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+                reject_loss_mask = reject_loss_mask.fill_(1.0)
             batch_size = chosen_input_ids.size()[0]
             actor_out = self.model(
                 input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
@@ -263,6 +270,11 @@ class ORPOTrainer(SLTrainer):
                 batch["reject_attention_mask"],
                 batch["reject_loss_mask"],
             )
+            if not self.apply_loss_mask:
+                chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+                reject_loss_mask = reject_loss_mask.fill_(1.0)
             batch_size = chosen_input_ids.size()[0]
             actor_out = self.model(
                 input_ids=torch.cat([chosen_input_ids, reject_input_ids]),

applications/ColossalChat/coati/trainer/ppo.py (12 lines changed)

@@ -102,6 +102,7 @@ class PPOTrainer(OLTrainer):
         sample_buffer: bool = False,
         dataloader_pin_memory: bool = True,
         offload_inference_models: bool = True,
+        apply_loss_mask: bool = True,
         accumulation_steps: int = 1,
         save_interval: int = 0,
         save_dir: str = None,
@@ -140,6 +141,7 @@ class PPOTrainer(OLTrainer):
         self.actor_optim = actor_optim
         self.critic_optim = critic_optim
         self.save_interval = save_interval
+        self.apply_loss_mask = apply_loss_mask
         self.coordinator = coordinator
         self.actor_save_dir = os.path.join(save_dir, "actor")
         self.critic_save_dir = os.path.join(save_dir, "critic")
@@ -229,7 +231,10 @@ class PPOTrainer(OLTrainer):
         action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
         actor_loss, to_skip, max_ratio = self.actor_loss_fn(
-            action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+            action_log_probs,
+            experience.action_log_probs,
+            experience.advantages,
+            action_mask=experience.action_mask if self.apply_loss_mask else None,
         )
         actor_loss = (1 - self.ptx_coef) * actor_loss
         if not to_skip:
@@ -249,7 +254,10 @@ class PPOTrainer(OLTrainer):
             input_ids=experience.sequences, attention_mask=experience.attention_mask
         )  # [batch size, prompt_length + response_length]
         critic_loss = self.critic_loss_fn(
-            values[:, -num_actions:], experience.values, experience.advantages, action_mask=experience.action_mask
+            values[:, -num_actions:],
+            experience.values,
+            experience.advantages,
+            action_mask=experience.action_mask if self.apply_loss_mask else None,
         )
         critic_loss = critic_loss * self.vf_coef
         self.critic_booster.backward(loss=critic_loss, optimizer=self.critic_optim)

applications/ColossalChat/coati/trainer/sft.py (14 lines changed)

@@ -41,6 +41,7 @@ class SFTTrainer(SLTrainer):
         lr_scheduler: _LRScheduler,
         max_epochs: int = 2,
         accumulation_steps: int = 8,
+        apply_loss_mask: bool = True,
         start_epoch=0,
         save_interval: int = None,
         save_dir: str = None,
@@ -55,6 +56,7 @@ class SFTTrainer(SLTrainer):
         self.coordinator = coordinator
         self.num_train_step = 0
         self.num_eval_step = 0
+        self.apply_loss_mask = apply_loss_mask
         self.accumulative_meter = AccumulativeMeanMeter()

     def _before_fit(
@@ -100,7 +102,11 @@ class SFTTrainer(SLTrainer):
         for i, batch in enumerate(self.train_dataloader):
             batch = to_device(batch, torch.cuda.current_device())
             batch_size = batch["input_ids"].size(0)
-            outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
+            outputs = self.model(
+                batch["input_ids"],
+                attention_mask=batch["attention_mask"],
+                labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+            )
             loss = outputs.loss
             self.booster.backward(loss=loss, optimizer=self.optimizer)
@@ -158,7 +164,11 @@ class SFTTrainer(SLTrainer):
         )
         for batch in self.eval_dataloader:
             batch = to_device(batch, torch.cuda.current_device())
-            outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
+            outputs = self.model(
+                batch["input_ids"],
+                attention_mask=batch["attention_mask"],
+                labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+            )
             loss_mean = all_reduce_mean(tensor=outputs.loss)
             self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
             step_bar.update()
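
For SFT, disabling the mask is implemented by passing the raw input_ids back in as labels, so the causal-LM loss covers prompt tokens as well as response tokens; with the mask enabled, batch["labels"] keeps non-response positions at the ignore index (-100 by Hugging Face convention). A standalone sketch of the difference (the gpt2 checkpoint and the "last two tokens are the response" split are placeholders, not the trainer's data pipeline):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
tokenizer = AutoTokenizer.from_pretrained("gpt2")

input_ids = tokenizer("Q: 1+1? A: 2", return_tensors="pt")["input_ids"]

# masked labels: compute the loss only on the (here, last two) response tokens
masked_labels = input_ids.clone()
masked_labels[:, :-2] = -100  # -100 = positions ignored by the loss

loss_masked = model(input_ids, labels=masked_labels).loss  # apply_loss_mask=True
loss_overall = model(input_ids, labels=input_ids).loss     # apply_loss_mask=False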

applications/ColossalChat/examples/README.md (1 line changed)

@@ -387,6 +387,7 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
 - save_dir: path to store the model checkpoints.
 - max_length: input will be padded/truncated to max_length before feeding to the model.
 - max_epochs: number of epochs to train.
+- disable_loss_mask: whether to disable the loss mask. For example, in SFT, if the loss mask is disabled, the model computes the loss across all tokens in the sequence; if the loss mask is applied, only the tokens corresponding to the assistant responses contribute to the final loss.
 - batch_size: training batch size.
 - mixed_precision: precision to use in training. Support 'fp16' and 'bf16'. Note that some devices may not support the 'bf16' option, please refer to [Nvidia](https://developer.nvidia.com/) to check compatibility.
 - save_interval: save the model weights as well as optimizer/scheduler states every save_interval steps/episodes.
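
In the preference trainers above (DPO/ORPO/KTO), the flag works by overwriting the per-token loss mask with ones, so every token's log-probability enters the loss rather than only the response tokens. A toy sketch of the effect (tensor shapes and values are illustrative, not trainer code):

import torch

# per-token log-probs of a sequence, [batch, seq_len]
logp = torch.randn(1, 6)
# loss mask: 1 on assistant-response tokens, 0 on prompt tokens
loss_mask = torch.tensor([[0.0, 0.0, 0.0, 1.0, 1.0, 1.0]])

response_only = (logp * loss_mask).sum(-1)          # default behaviour
all_tokens = (logp * loss_mask.fill_(1.0)).sum(-1)  # --disable_loss_mask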

applications/ColossalChat/examples/inference/inference.py (4 lines changed)

@@ -53,8 +53,8 @@ def load_model_and_tokenizer(model_path, tokenizer_path, device="cuda", **kwargs
         tuple: A tuple containing the loaded model and tokenizer.
     """
-    model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+    model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs, trust_remote_code=True).to(torch.bfloat16)
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
     tokenizer.pad_token = tokenizer.eos_token
     model.to(device)

applications/ColossalChat/examples/inference/round.txt (104 lines changed; file deleted)

@@ -1,104 +0,0 @@
==========
round 1:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
tell me a story [/INST] Great, let’s hear a story. </s>
==========
==========
round 2:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
tell me a story [/INST] Great, let’s hear a story. </s><s>[INST] calculate 1+1 [/INST] 1+1=2 </s>
==========
==========
round 3:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
tell me a story [/INST] Great, let’s hear a story. </s><s>[INST] calculate 1+1 [/INST] 1+1=2 </s><s>[INST] who is the first president of the USA [/INST] The first president of the United States was George Washington. </s>
==========
==========
round 1:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
who is the first president of the USA? [/INST] The first president of the United States was George Washington. </s>
==========
==========
round 2:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
who is the first president of the USA? [/INST] The first president of the United States was George Washington. </s><s>[INST] tell me a story [/INST] One story that might be interesting is the story of how the United States was founded. In 1776, the Thirteen Colonies united together to form the new nation of America. The first president of the new nation was George Washington. The first state was Pennsylvania. The first capital was Philadelphia. The first presidential election was held in 1787, and was between George Read, a Federalist, and John Adams, a Republican. The Federalists wanted to keep the power of the federal government limited, while the Republicans wanted the power to be spread around to the states. The Federalists won, and George Washington was elected president. </s>
==========
==========
round 1:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? </s>
==========
==========
round 2:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? </s><s>[INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? </s>
==========
==========
round 3:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? </s><s>[INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? </s><s>[INST] Is Donald Trump the president of the United States [/INST] Yes, Donald Trump became the 45th president of the United States in January of 2016. </s>
==========

applications/ColossalChat/examples/training_scripts/train_dpo.py (2 lines changed)

@@ -278,6 +278,7 @@ def train(args):
         beta=args.beta,
         gamma=args.gamma,
         length_normalization=args.length_normalization,
+        apply_loss_mask=not args.disable_loss_mask,
     )
     trainer.fit(
@@ -346,6 +347,7 @@ if __name__ == "__main__":
         default=False,
         help="Disable the reference model (enabled by default)",
     )
+    parser.add_argument("--disable_loss_mask", default=False, action="store_true")
     parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
     parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
     parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")

applications/ColossalChat/examples/training_scripts/train_kto.py (2 lines changed)

@@ -297,6 +297,7 @@ def train(args):
         beta=args.beta,
         desirable_weight=args.desirable_weight,
         undesirable_weight=args.undesirable_weight,
+        apply_loss_mask=not args.disable_loss_mask,
     )
     trainer.fit(
@@ -341,6 +342,7 @@ if __name__ == "__main__":
     parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss")
     parser.add_argument("--desirable_weight", type=float, default=1.0, help="desirable_weight in KTO loss")
     parser.add_argument("--undesirable_weight", type=float, default=1.0, help="undesirable_weight in KTO loss")
+    parser.add_argument("--disable_loss_mask", default=False, action="store_true")
     parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
     parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
     parser.add_argument("--zero_cpu_offload", default=False, action="store_true")

applications/ColossalChat/examples/training_scripts/train_orpo.py (2 lines changed)

@@ -259,6 +259,7 @@ def train(args):
         save_dir=args.save_dir,
         coordinator=coordinator,
         lam=args.lam,
+        apply_loss_mask=not args.disable_loss_mask,
     )
     trainer.fit(
@@ -301,6 +302,7 @@ if __name__ == "__main__":
     parser.add_argument("--pp", type=int, default=1)
     parser.add_argument("--sp", type=int, default=1)
     parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss")
+    parser.add_argument("--disable_loss_mask", default=False, action="store_true")
     parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
     parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
     parser.add_argument("--zero_cpu_offload", default=False, action="store_true")

applications/ColossalChat/examples/training_scripts/train_ppo.py (2 lines changed)

@@ -411,6 +411,7 @@ def train(args):
         use_cache=True,
         do_sample=True,
         temperature=0.7,
+        apply_loss_mask=not args.disable_loss_mask,
         accumulation_steps=args.accumulation_steps,
         save_dir=args.save_path,
         save_interval=args.save_interval,
@@ -498,6 +499,7 @@ if __name__ == "__main__":
     parser.add_argument("--critic_lr", type=float, default=9e-6)
     parser.add_argument("--kl_coef", type=float, default=0.1)
     parser.add_argument("--ptx_coef", type=float, default=0.0)
+    parser.add_argument("--disable_loss_mask", default=False, action="store_true")
     parser.add_argument("--max_length", type=int, default=2048)
     parser.add_argument("--max_seq_len", type=int, default=256)
     parser.add_argument("--log_dir", default="logs", type=str)

applications/ColossalChat/examples/training_scripts/train_sft.py (2 lines changed)

@@ -272,6 +272,7 @@ def train(args):
         lr_scheduler=lr_scheduler,
         max_epochs=args.max_epochs,
         accumulation_steps=args.accumulation_steps,
+        apply_loss_mask=not args.disable_loss_mask,
         start_epoch=start_epoch,
         save_interval=args.save_interval,
         save_dir=args.save_path,
@@ -317,6 +318,7 @@ if __name__ == "__main__":
     parser.add_argument("--tp", type=int, default=1)
     parser.add_argument("--pp", type=int, default=1)
     parser.add_argument("--sp", type=int, default=1)
+    parser.add_argument("--disable_loss_mask", default=False, action="store_true")
     parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
     parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
     parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
