diff --git a/applications/ChatGPT/chatgpt/dataset/sft_dataset.py b/applications/ChatGPT/chatgpt/dataset/sft_dataset.py
index 67e1b761c..11ec61908 100644
--- a/applications/ChatGPT/chatgpt/dataset/sft_dataset.py
+++ b/applications/ChatGPT/chatgpt/dataset/sft_dataset.py
@@ -54,7 +54,8 @@ class SFTDataset(Dataset):
 
     def __init__(self, dataset, tokenizer: Callable, max_length: int=512) -> None:
         super().__init__()
-        self.prompts = []
+        self.input_ids = []
 
         for data in tqdm(dataset, disable=not is_rank_0()):
             prompt = data['prompt'] + data['completion'] + "<|endoftext|>"
@@ -64,14 +65,18 @@ class SFTDataset(Dataset):
                                      truncation=True,
                                      return_tensors="pt")
-            self.prompts.append(prompt_token)
+            self.input_ids.append(prompt_token)
+        self.labels = copy.deepcopy(self.input_ids)
 
     def __len__(self):
-        length = len(self.prompts)
+        length = len(self.input_ids)
         return length
 
     def __getitem__(self, idx):
-        return self.prompts[idx]
+        return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
 
 
 def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
diff --git a/applications/ChatGPT/chatgpt/trainer/sft.py b/applications/ChatGPT/chatgpt/trainer/sft.py
index dd5cd35f5..3b35f5168 100644
--- a/applications/ChatGPT/chatgpt/trainer/sft.py
+++ b/applications/ChatGPT/chatgpt/trainer/sft.py
@@ -63,11 +63,13 @@ class SFTTrainer(ABC):
             for batch_id, batch in enumerate(self.train_dataloader):
                 prompt_ids = batch["input_ids"]
                 p_mask = batch["attention_mask"]
+                labels = batch["labels"]
                 prompt_ids = prompt_ids.squeeze(1).cuda()
                 p_mask = p_mask.squeeze(1).cuda()
-                prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
-                loss = self.loss_fn(prompt_logits, prompt_ids)
+                labels = labels.squeeze(1).cuda()
+                outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
+                loss = outputs.loss
+                prompt_logits = outputs.logits
                 self.strategy.backward(loss, self.model, self.optimizer)
                 self.strategy.optimizer_step(self.optimizer)
                 self.optimizer.zero_grad()
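
For context on why the trainer no longer needs its own `loss_fn`: Hugging Face causal-LM models compute the loss internally when `labels` are passed, shifting the labels one position and applying token-level cross-entropy. The sketch below is not part of the patch; it assumes a GPT-2-style `AutoModelForCausalLM` (the checkpoint and variable names are illustrative) and reproduces that internal loss by hand.

```python
# Minimal sketch (assumption: a GPT-2-style causal LM), showing that the loss
# returned when `labels` are supplied equals the shift-by-one cross-entropy.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

text = "example prompt" + " example completion" + "<|endoftext|>"
enc = tokenizer(text, return_tensors="pt")
labels = enc["input_ids"].clone()  # same idea as copy.deepcopy(self.input_ids)

with torch.no_grad():
    out = model(enc["input_ids"], attention_mask=enc["attention_mask"], labels=labels)

# Reproduce the internal computation: position t predicts token t+1.
shift_logits = out.logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
manual_loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                              shift_labels.view(-1))
print(out.loss.item(), manual_loss.item())  # the two values should agree
```

On the `squeeze(1)` calls in the trainer: because the dataset tokenizes each sample with `return_tensors="pt"`, every stored tensor has shape `(1, max_length)`, so default collation yields batches of shape `(batch, 1, max_length)`; squeezing dimension 1 restores `(batch, max_length)` for `input_ids`, `attention_mask`, and now `labels` alike.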