diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py index f701cfdf9..7f43a45b6 100755 --- a/applications/ColossalChat/coati/dataset/loader.py +++ b/applications/ColossalChat/coati/dataset/loader.py @@ -207,23 +207,6 @@ class DataCollatorForPreferenceDataset(object): chuncate_sequence([ins["rejected_loss_mask"] for ins in instances], self.max_length, torch.bool), ) - for i in range(len(chosen_loss_mask)): - if sum(chosen_loss_mask[i][1:]) == 0: - print( - "After truncated", - chosen_loss_mask[i], - len(chosen_loss_mask[i]), - len(instances[i]["chosen_input_ids"]), - ) - for i in range(len(reject_loss_mask)): - if sum(reject_loss_mask[i][1:]) == 0: - print( - "After truncated", - reject_loss_mask[i], - len(reject_loss_mask[i]), - len(instances[i]["rejected_input_ids"]), - ) - padding_side = self.tokenizer.padding_side chosen_attention_mask = [torch.ones_like(seq).bool() for seq in chosen_input_ids] reject_attention_mask = [torch.ones_like(seq).bool() for seq in reject_input_ids] diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.sh b/applications/ColossalChat/examples/training_scripts/train_sft.sh index d5ba6261e..0f6e09f6f 100755 --- a/applications/ColossalChat/examples/training_scripts/train_sft.sh +++ b/applications/ColossalChat/examples/training_scripts/train_sft.sh @@ -23,16 +23,16 @@ PARENT_CONFIG_FILE="" # Path to a folder to save training config logs PRETRAINED_MODEL_PATH="" # huggingface or local model path PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path declare -a dataset=( - /Your/Preference/Data/arrow/part-00000 - /Your/Preference/Data/arrow/part-00001 - /Your/Preference/Data/arrow/part-00002 - /Your/Preference/Data/arrow/part-00003 - /Your/Preference/Data/arrow/part-00004 - /Your/Preference/Data/arrow/part-00005 - /Your/Preference/Data/arrow/part-00006 - /Your/Preference/Data/arrow/part-00007 - /Your/Preference/Data/arrow/part-00008 - /Your/Preference/Data/arrow/part-00009 + /Your/SFT/Data/arrow/part-00000 + /Your/SFT/Data/arrow/part-00001 + /Your/SFT/Data/arrow/part-00002 + /Your/SFT/Data/arrow/part-00003 + /Your/SFT/Data/arrow/part-00004 + /Your/SFT/Data/arrow/part-00005 + /Your/SFT/Data/arrow/part-00006 + /Your/SFT/Data/arrow/part-00007 + /Your/SFT/Data/arrow/part-00008 + /Your/SFT/Data/arrow/part-00009 ) TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)