From 7788e0b0a5dcceab94fa5e98d25730003ab726a8 Mon Sep 17 00:00:00 2001
From: tingfeng cao <982912719@qq.com>
Date: Mon, 17 Apr 2023 16:47:44 +0800
Subject: [PATCH] fix: fix sft (#3568)

---
 applications/Chat/coati/dataset/sft_dataset.py | 12 ++++--------
 applications/Chat/coati/trainer/sft.py         | 15 ++++++++-------
 applications/Chat/examples/train_sft.py        |  2 +-
 3 files changed, 13 insertions(+), 16 deletions(-)

diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py
index 91e38f06d..3e2453468 100644
--- a/applications/Chat/coati/dataset/sft_dataset.py
+++ b/applications/Chat/coati/dataset/sft_dataset.py
@@ -53,29 +53,25 @@ class SFTDataset(Dataset):
 
     def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
         super().__init__()
-        # self.prompts = []
         self.input_ids = []
 
         for data in tqdm(dataset, disable=not is_rank_0()):
-            prompt = data['prompt'] + data['completion'] + "<|endoftext|>"
+            prompt = data['prompt'] + data['completion'] + tokenizer.eos_token
             prompt_token = tokenizer(prompt,
                                      max_length=max_length,
                                      padding="max_length",
                                      truncation=True,
                                      return_tensors="pt")
-            # self.prompts.append(prompt_token)s
-            self.input_ids.append(prompt_token)
-        self.labels = copy.deepcopy(self.input_ids)
+            self.input_ids.append(prompt_token['input_ids'][0])
+        self.labels = copy.deepcopy(self.input_ids)
 
     def __len__(self):
-        length = len(self.prompts)
+        length = len(self.input_ids)
         return length
 
     def __getitem__(self, idx):
-        # dict(input_ids=self.input_ids[i], labels=self.labels[i])
         return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
-        # return dict(self.prompts[idx], self.prompts[idx])
 
 
 def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:
diff --git a/applications/Chat/coati/trainer/sft.py b/applications/Chat/coati/trainer/sft.py
index 8eeffea48..f380cbf06 100644
--- a/applications/Chat/coati/trainer/sft.py
+++ b/applications/Chat/coati/trainer/sft.py
@@ -96,7 +96,7 @@ class SFTTrainer(ABC):
                 loss = outputs.loss
                 prompt_logits = outputs.logits
 
-                if loss >= 2.5:
+                if loss >= 2.5 and is_rank_0():
                     logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")
 
                 loss = loss / self.accimulation_steps
@@ -110,12 +110,13 @@ class SFTTrainer(ABC):
                     self.strategy.optimizer_step(self.optimizer)
                     self.optimizer.zero_grad()
                     self.scheduler.step()
-                    wandb.log({
-                        "loss": total_loss / self.accimulation_steps,
-                        "lr": self.scheduler.get_last_lr()[0],
-                        "epoch": epoch,
-                        "batch_id": batch_id
-                    })
+                    if is_rank_0():
+                        wandb.log({
+                            "loss": total_loss / self.accimulation_steps,
+                            "lr": self.scheduler.get_last_lr()[0],
+                            "epoch": epoch,
+                            "batch_id": batch_id
+                        })
                     total_loss = 0
                     step_bar.update()
 
diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py
index 22f70e485..d7502c23b 100644
--- a/applications/Chat/examples/train_sft.py
+++ b/applications/Chat/examples/train_sft.py
@@ -111,7 +111,7 @@ def train(args):
                                           max_datasets_size=args.max_datasets_size,
                                           max_length=max_len)
         eval_dataset = None
-        data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
+    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
 
     if dist.is_initialized() and dist.get_world_size() > 1:
         train_sampler = DistributedSampler(train_dataset,
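
Illustrative usage of the patched SFTDataset, as a minimal sketch only: it assumes the coati package under applications/Chat is importable and substitutes a Hugging Face GPT-2 tokenizer plus toy samples for the real training data, so the "gpt2" checkpoint, the sample text, and max_length=32 are placeholders, not values taken from this patch.

from transformers import AutoTokenizer

from coati.dataset import SFTDataset

# Assumption: any tokenizer that defines eos_token works here, since the dataset
# now appends tokenizer.eos_token instead of a hard-coded "<|endoftext|>" string.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has an eos_token but no pad_token

# Toy samples with the 'prompt' and 'completion' keys SFTDataset expects.
samples = [{"prompt": "Q: What is 2 + 2?\nA:", "completion": " 4"}]
dataset = SFTDataset(samples, tokenizer, max_length=32)

item = dataset[0]
# After the fix, each item is a dict of 1-D tensors of length max_length
# (rather than a whole BatchEncoding), which is what __len__ and the
# supervised data collator expect.
print(item["input_ids"].shape)  # torch.Size([32])
print(item["labels"].shape)     # torch.Size([32])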