mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt] unnify datasets (#3218)
parent 4fd4bd9d9a
commit fa97a9cab4
@@ -54,7 +54,7 @@ class SFTDataset(Dataset):

     def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
         super().__init__()
-        self.prompts = []
+        self.input_ids = []

         for data in tqdm(dataset, disable=not is_rank_0()):
             prompt = data['prompt'] + data['completion'] + "<|endoftext|>"

@@ -64,14 +64,15 @@ class SFTDataset(Dataset):
                                      truncation=True,
                                      return_tensors="pt")

-            self.prompts.append(prompt_token)
+            self.input_ids.append(prompt_token)
+        self.labels = copy.deepcopy(self.input_ids)

     def __len__(self):
-        length = len(self.prompts)
+        length = len(self.input_ids)
         return length

     def __getitem__(self, idx):
-        return self.prompts[idx]
+        return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])


 def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
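The line labels = copy.deepcopy(self.input_ids) follows the standard causal-LM convention: labels start out identical to the inputs, and the language model shifts them one position internally when computing cross-entropy. A minimal sketch of that convention with a stock HuggingFace GPT-2 model (illustrative only; the model wrapper this commit actually trains is not shown in the diff):

    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    enc = tokenizer("a prompt and its completion<|endoftext|>", return_tensors="pt")

    # Labels identical to input_ids: the model shifts logits and labels
    # internally, so out.loss is next-token cross-entropy over the sequence.
    out = model(input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                labels=enc["input_ids"])
    print(out.loss)

Note also that because the dataset tokenizes each sample with return_tensors="pt", every stored item keeps a leading batch dimension of 1; this is why the trainer loop below calls .squeeze(1) after batching.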
@@ -63,11 +63,11 @@ class SFTTrainer(ABC):
         for batch_id, batch in enumerate(self.train_dataloader):
             prompt_ids = batch["input_ids"]
             p_mask = batch["attention_mask"]
+            labels = batch["labels"]
             prompt_ids = prompt_ids.squeeze(1).cuda()
             p_mask = p_mask.squeeze(1).cuda()
-            prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
+            loss, prompt_logits = self.model(prompt_ids, attention_mask=p_mask, labels=labels)

-            loss = self.loss_fn(prompt_logits, prompt_ids)
             self.strategy.backward(loss, self.model, self.optimizer)
             self.strategy.optimizer_step(self.optimizer)
             self.optimizer.zero_grad()
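The new unpacking loss, prompt_logits = self.model(...) implies the wrapped model returns a (loss, logits) pair whenever labels are supplied, which is what lets the external self.loss_fn call be dropped. Below is a hedged sketch of a wrapper with that calling convention (CausalLMWrapper is a hypothetical name; the wrapper class this trainer actually uses is not shown in the diff). One further assumption worth flagging: labels presumably needs the same .squeeze(1).cuda() treatment as input_ids and the attention mask before the forward call, since all three come from the same tokenizer output.

    import torch.nn as nn
    from transformers import GPT2LMHeadModel

    class CausalLMWrapper(nn.Module):
        # Hypothetical wrapper illustrating the (loss, logits) calling
        # convention the trainer loop above relies on; not from this commit.

        def __init__(self, model: GPT2LMHeadModel):
            super().__init__()
            self.model = model

        def forward(self, input_ids, attention_mask=None, labels=None):
            out = self.model(input_ids=input_ids,
                             attention_mask=attention_mask,
                             labels=labels)
            if labels is not None:
                # Matches: loss, prompt_logits = self.model(..., labels=labels)
                return out.loss, out.logits
            return out.logits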