[Chat] fix sft loss nan (#5345)

* fix script

* fix script

* fix chat nan

* fix chat nan
pull/5355/head
YeAnbang 2024-02-01 14:25:16 +08:00 committed by GitHub
parent abd8e77ad8
commit c5239840e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 5 deletions

View File

@ -49,12 +49,13 @@ def _preprocess(
max_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Preprocess the data by tokenizing."""
sequences = [s + t for s, t in zip(sources, targets)]
sequences = [s + t + tokenizer.eos_token for s, t in zip(sources, targets)]
sequences_token = tokenizer(
sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=False
)
sources_token = tokenizer(
sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=False
)
assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently"
@ -65,7 +66,8 @@ def _preprocess(
if tokenizer.padding_side == "right":
# |prompt|completion|eos|pad|
labels[i][:source_len] = IGNORE_INDEX
labels[i][-pad_len:] = IGNORE_INDEX
if pad_len>0:
labels[i][-pad_len:] = IGNORE_INDEX
elif tokenizer.padding_side == "left":
# |pad|prompt|completion|eos|
labels[i][: pad_len + source_len] = IGNORE_INDEX

View File

@ -25,4 +25,4 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
--accumulation_steps 8 \
--lr 2e-5 \
--max_datasets_size 512 \
--max_epochs 1
--max_epochs 1