From c5239840e62f8231236adaa97509b6ab18279fd6 Mon Sep 17 00:00:00 2001
From: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Date: Thu, 1 Feb 2024 14:25:16 +0800
Subject: [PATCH] [Chat] fix sft loss nan (#5345)

* fix script

* fix script

* fix chat nan

* fix chat nan
---
 applications/Chat/coati/dataset/sft_dataset.py | 10 ++++++----
 applications/Chat/examples/train_sft.sh        |  2 +-
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py
index c0e257f54..e67e16231 100644
--- a/applications/Chat/coati/dataset/sft_dataset.py
+++ b/applications/Chat/coati/dataset/sft_dataset.py
@@ -49,12 +49,13 @@ def _preprocess(
     max_length: int,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """Preprocess the data by tokenizing."""
-    sequences = [s + t for s, t in zip(sources, targets)]
+    sequences = [s + t + tokenizer.eos_token for s, t in zip(sources, targets)]
     sequences_token = tokenizer(
-        sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+        sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=False
     )
+
     sources_token = tokenizer(
-        sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+        sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=False
     )
 
     assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently"
@@ -65,7 +66,8 @@ def _preprocess(
         if tokenizer.padding_side == "right":
             # |prompt|completion|eos|pad|
             labels[i][:source_len] = IGNORE_INDEX
-            labels[i][-pad_len:] = IGNORE_INDEX
+            if pad_len>0:
+                labels[i][-pad_len:] = IGNORE_INDEX
         elif tokenizer.padding_side == "left":
             # |pad|prompt|completion|eos|
             labels[i][: pad_len + source_len] = IGNORE_INDEX
diff --git a/applications/Chat/examples/train_sft.sh b/applications/Chat/examples/train_sft.sh
index 0fb4da3d3..b7d176847 100755
--- a/applications/Chat/examples/train_sft.sh
+++ b/applications/Chat/examples/train_sft.sh
@@ -25,4 +25,4 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
     --accumulation_steps 8 \
     --lr 2e-5 \
     --max_datasets_size 512 \
-    --max_epochs 1
+    --max_epochs 1
\ No newline at end of file
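
Why the added guard fixes the NaN, as a minimal standalone sketch (illustrative
only, not part of the patch): when a tokenized sequence already fills
max_length, pad_len is 0, and labels[i][-0:] is equivalent to labels[i][0:],
i.e. the entire row. Every label in that row then becomes IGNORE_INDEX,
cross-entropy has zero valid tokens to average over, and the mean reduces to
NaN. The tensor names and sizes below are hypothetical stand-ins, not taken
from sft_dataset.py:

    import torch
    import torch.nn.functional as F

    IGNORE_INDEX = -100

    labels = torch.arange(6)   # stand-in for one row of token labels
    pad_len = 0                # sequence exactly fills max_length

    # Buggy masking: -0 == 0, so labels[-pad_len:] selects the WHOLE row.
    buggy = labels.clone()
    buggy[-pad_len:] = IGNORE_INDEX          # tensor([-100, ..., -100])

    # Guarded masking, as in the patch: only mask when real padding exists.
    fixed = labels.clone()
    if pad_len > 0:
        fixed[-pad_len:] = IGNORE_INDEX      # row left untouched here

    # With every target ignored, mean cross-entropy is 0/0 -> NaN.
    logits = torch.randn(6, 32)              # hypothetical 32-token vocab
    print(F.cross_entropy(logits, buggy, ignore_index=IGNORE_INDEX))   # nan
    print(F.cross_entropy(logits, fixed, ignore_index=IGNORE_INDEX))   # finite

The tokenizer change appears to serve the same goal: appending
tokenizer.eos_token manually while passing add_special_tokens=False keeps the
tokenizer from inserting its own special tokens on top of the manual EOS, so
the source and sequence lengths used for label masking stay consistent.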