diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py
index 9eb2eba87..4f890ffc9 100755
--- a/applications/ColossalChat/coati/dataset/tokenization_utils.py
+++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -49,6 +49,10 @@ def tokenize_sft(
messages = data_point["messages"]
template = deepcopy(conversation_template)
+
+ if messages[0]["from"] == "system":
+ template.system_message = str(messages[0]["content"])
+ messages.pop(0)
template.messages = []
for idx, mess in enumerate(messages):
if mess["from"] != template.roles[idx % 2]:
@@ -148,11 +152,14 @@ def tokenize_prompt(
template = deepcopy(conversation_template)
template.messages = []
+ if messages[0]["from"] == "system":
+ template.system_message = str(messages[0]["content"])
+ messages.pop(0)
+
for idx, mess in enumerate(messages):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
- f"Message should iterate between user and assistant and starts with a \
- line from the user. Got the following data:\n{messages}"
+ f"Message should iterate between user and assistant and starts with a line from the user. Got the following data:\n{messages}"
)
template.append_message(mess["from"], mess["content"])
@@ -225,6 +232,10 @@ def tokenize_rlhf(
template = deepcopy(conversation_template)
template.clear()
+ if context[0]["from"] == "system":
+ template.system_message = str(context[0]["content"])
+ context.pop(0)
+
for idx, mess in enumerate(context):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
@@ -345,6 +356,10 @@ def tokenize_kto(
template = deepcopy(conversation_template)
template.clear()
+ if prompt[0]["from"] == "system":
+ template.system_message = str(prompt[0]["content"])
+ prompt.pop(0)
+
if prompt[0].get("from", None) != "user":
raise ValueError("conversation should start with user")
if completion.get("from", None) != "assistant":
diff --git a/applications/ColossalChat/coati/models/loss.py b/applications/ColossalChat/coati/models/loss.py
index 840cca074..bd0bbd36b 100755
--- a/applications/ColossalChat/coati/models/loss.py
+++ b/applications/ColossalChat/coati/models/loss.py
@@ -46,7 +46,10 @@ class PolicyLoss(nn.Module):
action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
skip = False
- ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
+ if action_mask is None:
+ ratio_ = (log_probs - old_log_probs).exp()
+ else:
+ ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
# note that if dropout is disabled (recommanded), ratio will always be 1.
if ratio_.mean() > self.skip_threshold:
@@ -56,7 +59,10 @@ class PolicyLoss(nn.Module):
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
loss = -torch.min(surr1, surr2)
- loss = masked_mean(loss, action_mask)
+ if action_mask is not None:
+ loss = masked_mean(loss, action_mask)
+ else:
+ loss = loss.mean(dim=1)
loss = loss.mean()
return loss, skip, ratio_.max()
@@ -81,8 +87,10 @@ class ValueLoss(nn.Module):
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
surr1 = (values_clipped - returns) ** 2
surr2 = (values - returns) ** 2
- loss = torch.max(surr1, surr2) / torch.sum(action_mask)
- loss = torch.sum(loss * action_mask)
+ if action_mask is not None:
+ loss = torch.sum(torch.max(surr1, surr2) / torch.sum(action_mask) * action_mask)
+ else:
+ loss = torch.mean(torch.max(surr1, surr2))
return 0.5 * loss
diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py
index c7ef2be8f..24ddca654 100755
--- a/applications/ColossalChat/coati/trainer/dpo.py
+++ b/applications/ColossalChat/coati/trainer/dpo.py
@@ -56,6 +56,7 @@ class DPOTrainer(SLTrainer):
beta: float = 0.1,
gamma: float = 0.0,
length_normalization: bool = False,
+ apply_loss_mask: bool = True,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
@@ -67,6 +68,7 @@ class DPOTrainer(SLTrainer):
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.actor_loss_fn = DpoLoss(beta, gamma)
+ self.apply_loss_mask = apply_loss_mask
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
@@ -135,6 +137,10 @@ class DPOTrainer(SLTrainer):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
+ if not self.apply_loss_mask:
+ chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+ reject_loss_mask = reject_loss_mask.fill_(1.0)
+
batch_size = chosen_input_ids.size()[0]
actor_all_logits = self.model(
@@ -284,6 +290,9 @@ class DPOTrainer(SLTrainer):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
+ if not self.apply_loss_mask:
+ chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+ reject_loss_mask = reject_loss_mask.fill_(1.0)
batch_size = chosen_input_ids.size()[0]
diff --git a/applications/ColossalChat/coati/trainer/kto.py b/applications/ColossalChat/coati/trainer/kto.py
index 8ab0bc66b..6462ba816 100755
--- a/applications/ColossalChat/coati/trainer/kto.py
+++ b/applications/ColossalChat/coati/trainer/kto.py
@@ -6,7 +6,7 @@ import os
from typing import Any, Optional
import torch
-import torch.distributed
+import torch.distributed as dist
from coati.models.loss import KTOLoss
from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean
@@ -59,6 +59,7 @@ class KTOTrainer(SLTrainer):
beta: float = 0.1,
desirable_weight: float = 1.0,
undesirable_weight: float = 1.0,
+ apply_loss_mask: bool = True,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
@@ -70,6 +71,7 @@ class KTOTrainer(SLTrainer):
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.kto_loss = KTOLoss(beta=beta, desirable_weight=desirable_weight, undesirable_weight=undesirable_weight)
+ self.apply_loss_mask = apply_loss_mask
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
@@ -134,6 +136,10 @@ class KTOTrainer(SLTrainer):
batch["kl_attention_mask"],
batch["kl_loss_mask"],
)
+ if not self.apply_loss_mask:
+ loss_mask = loss_mask.fill_(1.0)
+ kl_loss_mask = kl_loss_mask.fill_(1.0)
+
batch_size = input_ids.size()[0]
# actor logits
@@ -182,8 +188,28 @@ class KTOTrainer(SLTrainer):
# sync
loss_mean = all_reduce_mean(tensor=loss)
- chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean())
- rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean())
+ chosen_reward_mean = chosen_rewards.mean()
+ chosen_rewards_list = [
+ torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
+ ]
+ dist.all_gather(chosen_rewards_list, chosen_reward_mean)
+ rejected_reward_mean = rejected_rewards.mean()
+ rejected_rewards_list = [
+ torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
+ ]
+ dist.all_gather(rejected_rewards_list, rejected_reward_mean)
+ chosen_rewards_list = [i for i in chosen_rewards_list if not i.isnan()]
+ rejected_rewards_list = [i for i in rejected_rewards_list if not i.isnan()]
+ chosen_rewards_mean = (
+ torch.stack(chosen_rewards_list).mean()
+ if len(chosen_rewards_list) > 0
+ else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device)
+ )
+ rejected_rewards_mean = (
+ torch.stack(rejected_rewards_list).mean()
+ if len(rejected_rewards_list) > 0
+ else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device)
+ )
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
@@ -256,6 +282,11 @@ class KTOTrainer(SLTrainer):
batch["kl_attention_mask"],
batch["kl_loss_mask"],
)
+
+ if not self.apply_loss_mask:
+ loss_mask = loss_mask.fill_(1.0)
+ kl_loss_mask = kl_loss_mask.fill_(1.0)
+
batch_size = input_ids.size()[0]
# actor logits
diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py
index b039da4af..c2f75771c 100644
--- a/applications/ColossalChat/coati/trainer/orpo.py
+++ b/applications/ColossalChat/coati/trainer/orpo.py
@@ -52,6 +52,7 @@ class ORPOTrainer(SLTrainer):
tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1,
lam: float = 0.1,
+ apply_loss_mask: bool = True,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
@@ -67,6 +68,7 @@ class ORPOTrainer(SLTrainer):
self.save_dir = save_dir
self.num_train_step = 0
self.lam = lam
+ self.apply_loss_mask = apply_loss_mask
self.accumulation_steps = accumulation_steps
self.device = get_current_device()
self.accumulative_meter = AccumulativeMeanMeter()
@@ -130,6 +132,11 @@ class ORPOTrainer(SLTrainer):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
+
+ if not self.apply_loss_mask:
+ chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+ reject_loss_mask = reject_loss_mask.fill_(1.0)
+
batch_size = chosen_input_ids.size()[0]
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
@@ -263,6 +270,11 @@ class ORPOTrainer(SLTrainer):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
+
+ if not self.apply_loss_mask:
+ chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+ reject_loss_mask = reject_loss_mask.fill_(1.0)
+
batch_size = chosen_input_ids.size()[0]
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
diff --git a/applications/ColossalChat/coati/trainer/ppo.py b/applications/ColossalChat/coati/trainer/ppo.py
index 287767669..63c813b39 100755
--- a/applications/ColossalChat/coati/trainer/ppo.py
+++ b/applications/ColossalChat/coati/trainer/ppo.py
@@ -102,6 +102,7 @@ class PPOTrainer(OLTrainer):
sample_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
+ apply_loss_mask: bool = True,
accumulation_steps: int = 1,
save_interval: int = 0,
save_dir: str = None,
@@ -140,6 +141,7 @@ class PPOTrainer(OLTrainer):
self.actor_optim = actor_optim
self.critic_optim = critic_optim
self.save_interval = save_interval
+ self.apply_loss_mask = apply_loss_mask
self.coordinator = coordinator
self.actor_save_dir = os.path.join(save_dir, "actor")
self.critic_save_dir = os.path.join(save_dir, "critic")
@@ -229,7 +231,10 @@ class PPOTrainer(OLTrainer):
action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
actor_loss, to_skip, max_ratio = self.actor_loss_fn(
- action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+ action_log_probs,
+ experience.action_log_probs,
+ experience.advantages,
+ action_mask=experience.action_mask if self.apply_loss_mask else None,
)
actor_loss = (1 - self.ptx_coef) * actor_loss
if not to_skip:
@@ -249,7 +254,10 @@ class PPOTrainer(OLTrainer):
input_ids=experience.sequences, attention_mask=experience.attention_mask
) # [batch size, prompt_length + response_length]
critic_loss = self.critic_loss_fn(
- values[:, -num_actions:], experience.values, experience.advantages, action_mask=experience.action_mask
+ values[:, -num_actions:],
+ experience.values,
+ experience.advantages,
+ action_mask=experience.action_mask if self.apply_loss_mask else None,
)
critic_loss = critic_loss * self.vf_coef
self.critic_booster.backward(loss=critic_loss, optimizer=self.critic_optim)
diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py
index c09d61034..d37676ada 100755
--- a/applications/ColossalChat/coati/trainer/sft.py
+++ b/applications/ColossalChat/coati/trainer/sft.py
@@ -41,6 +41,7 @@ class SFTTrainer(SLTrainer):
lr_scheduler: _LRScheduler,
max_epochs: int = 2,
accumulation_steps: int = 8,
+ apply_loss_mask: bool = True,
start_epoch=0,
save_interval: int = None,
save_dir: str = None,
@@ -55,6 +56,7 @@ class SFTTrainer(SLTrainer):
self.coordinator = coordinator
self.num_train_step = 0
self.num_eval_step = 0
+ self.apply_loss_mask = apply_loss_mask
self.accumulative_meter = AccumulativeMeanMeter()
def _before_fit(
@@ -100,7 +102,11 @@ class SFTTrainer(SLTrainer):
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
batch_size = batch["input_ids"].size(0)
- outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
+ outputs = self.model(
+ batch["input_ids"],
+ attention_mask=batch["attention_mask"],
+ labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+ )
loss = outputs.loss
self.booster.backward(loss=loss, optimizer=self.optimizer)
@@ -158,7 +164,11 @@ class SFTTrainer(SLTrainer):
)
for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device())
- outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
+ outputs = self.model(
+ batch["input_ids"],
+ attention_mask=batch["attention_mask"],
+ labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+ )
loss_mean = all_reduce_mean(tensor=outputs.loss)
self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
step_bar.update()
diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md
index b749f197e..904d69cfc 100755
--- a/applications/ColossalChat/examples/README.md
+++ b/applications/ColossalChat/examples/README.md
@@ -387,6 +387,7 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
- save_dir: path to store the model checkpoints.
- max_length: input will be padded/truncated to max_length before feeding to the model.
- max_epochs: number of epochs to train.
+- disable_loss_mask: whether to use the loss mask to mask the loss or not. For example, in SFT, if the loss mask is disabled, the model will compute the loss across all tokens in the sequence, if the loss mask is applied, only tokens correspond to the assistant responses will contribute to the final loss.
- batch_size: training batch size.
- mixed_precision: precision to use in training. Support 'fp16' and 'bf16'. Note that some devices may not support the 'bf16' option, please refer to [Nvidia](https://developer.nvidia.com/) to check compatibility.
- save_interval: save the model weights as well as optimizer/scheduler states every save_interval steps/episodes.
diff --git a/applications/ColossalChat/examples/inference/inference.py b/applications/ColossalChat/examples/inference/inference.py
index 103bd8d95..5f59ba452 100755
--- a/applications/ColossalChat/examples/inference/inference.py
+++ b/applications/ColossalChat/examples/inference/inference.py
@@ -53,8 +53,8 @@ def load_model_and_tokenizer(model_path, tokenizer_path, device="cuda", **kwargs
tuple: A tuple containing the loaded model and tokenizer.
"""
- model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs, trust_remote_code=True).to(torch.bfloat16)
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
model.to(device)
diff --git a/applications/ColossalChat/examples/inference/round.txt b/applications/ColossalChat/examples/inference/round.txt
deleted file mode 100644
index ba02074c1..000000000
--- a/applications/ColossalChat/examples/inference/round.txt
+++ /dev/null
@@ -1,104 +0,0 @@
-
-
-==========
-round 1:
-[INST] <>
-A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
-
-
-<>
-
-tell me a story [/INST] Great, let’s hear a story.
-
-==========
-
-
-==========
-round 2:
-[INST] <>
-A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
-
-
-<>
-
-tell me a story [/INST] Great, let’s hear a story. [INST] calculate 1+1 [/INST] 1+1=2
-
-==========
-
-
-==========
-round 3:
-[INST] <>
-A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
-
-
-<>
-
-tell me a story [/INST] Great, let’s hear a story. [INST] calculate 1+1 [/INST] 1+1=2 [INST] who is the first president of the USA [/INST] The first president of the United States was George Washington.
-
-==========
-
-
-==========
-round 1:
-[INST] <>
-A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
-
-
-<>
-
-who is the first president of the USA? [/INST] The first president of the United States was George Washington.
-
-==========
-
-
-==========
-round 2:
-[INST] <>
-A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
-
-
-<>
-
-who is the first president of the USA? [/INST] The first president of the United States was George Washington. [INST] tell me a story [/INST] One story that might be interesting is the story of how the United States was founded. In 1776, the Thirteen Colonies united together to form the new nation of America. The first president of the new nation was George Washington. The first state was Pennsylvania. The first capital was Philadelphia. The first presidential election was held in 1787, and was between George Read, a Federalist, and John Adams, a Republican. The Federalists wanted to keep the power of the federal government limited, while the Republicans wanted the power to be spread around to the states. The Federalists won, and George Washington was elected president.
-
-==========
-
-
-==========
-round 1:
-[INST] <>
-A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
-
-
-<>
-
-tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear?
-
-==========
-
-
-==========
-round 2:
-[INST] <>
-A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
-
-
-<>
-
-tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? [INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump?
-
-==========
-
-
-==========
-round 3:
-[INST] <>
-A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
-
-
-<>
-
-tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? [INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? [INST] Is Donald Trump the president of the United States [/INST] Yes, Donald Trump became the 45th president of the United States in January of 2016.
-
-==========
diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py
index 44131f572..d88750aeb 100755
--- a/applications/ColossalChat/examples/training_scripts/train_dpo.py
+++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py
@@ -278,6 +278,7 @@ def train(args):
beta=args.beta,
gamma=args.gamma,
length_normalization=args.length_normalization,
+ apply_loss_mask=not args.disable_loss_mask,
)
trainer.fit(
@@ -346,6 +347,7 @@ if __name__ == "__main__":
default=False,
help="Disable the reference model (enabled by default)",
)
+ parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
diff --git a/applications/ColossalChat/examples/training_scripts/train_kto.py b/applications/ColossalChat/examples/training_scripts/train_kto.py
index d063b82bb..598fd8062 100755
--- a/applications/ColossalChat/examples/training_scripts/train_kto.py
+++ b/applications/ColossalChat/examples/training_scripts/train_kto.py
@@ -297,6 +297,7 @@ def train(args):
beta=args.beta,
desirable_weight=args.desirable_weight,
undesirable_weight=args.undesirable_weight,
+ apply_loss_mask=not args.disable_loss_mask,
)
trainer.fit(
@@ -341,6 +342,7 @@ if __name__ == "__main__":
parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss")
parser.add_argument("--desirable_weight", type=float, default=1.0, help="desirable_weight in KTO loss")
parser.add_argument("--undesirable_weight", type=float, default=1.0, help="undesirable_weight in KTO loss")
+ parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
diff --git a/applications/ColossalChat/examples/training_scripts/train_orpo.py b/applications/ColossalChat/examples/training_scripts/train_orpo.py
index f06524507..87860f7ea 100755
--- a/applications/ColossalChat/examples/training_scripts/train_orpo.py
+++ b/applications/ColossalChat/examples/training_scripts/train_orpo.py
@@ -259,6 +259,7 @@ def train(args):
save_dir=args.save_dir,
coordinator=coordinator,
lam=args.lam,
+ apply_loss_mask=not args.disable_loss_mask,
)
trainer.fit(
@@ -301,6 +302,7 @@ if __name__ == "__main__":
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss")
+ parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
diff --git a/applications/ColossalChat/examples/training_scripts/train_ppo.py b/applications/ColossalChat/examples/training_scripts/train_ppo.py
index 333be9963..c10418394 100755
--- a/applications/ColossalChat/examples/training_scripts/train_ppo.py
+++ b/applications/ColossalChat/examples/training_scripts/train_ppo.py
@@ -411,6 +411,7 @@ def train(args):
use_cache=True,
do_sample=True,
temperature=0.7,
+ apply_loss_mask=not args.disable_loss_mask,
accumulation_steps=args.accumulation_steps,
save_dir=args.save_path,
save_interval=args.save_interval,
@@ -498,6 +499,7 @@ if __name__ == "__main__":
parser.add_argument("--critic_lr", type=float, default=9e-6)
parser.add_argument("--kl_coef", type=float, default=0.1)
parser.add_argument("--ptx_coef", type=float, default=0.0)
+ parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--max_length", type=int, default=2048)
parser.add_argument("--max_seq_len", type=int, default=256)
parser.add_argument("--log_dir", default="logs", type=str)
diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py
index 6007a8599..c4ef3b783 100755
--- a/applications/ColossalChat/examples/training_scripts/train_sft.py
+++ b/applications/ColossalChat/examples/training_scripts/train_sft.py
@@ -272,6 +272,7 @@ def train(args):
lr_scheduler=lr_scheduler,
max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps,
+ apply_loss_mask=not args.disable_loss_mask,
start_epoch=start_epoch,
save_interval=args.save_interval,
save_dir=args.save_path,
@@ -317,6 +318,7 @@ if __name__ == "__main__":
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
+ parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")