mirror of https://github.com/InternLM/InternLM
import copy
import os

import torch
from datasets import Dataset as HFDataset
from datasets import load_dataset
from torch.utils.data import Dataset


class SFTDataset(Dataset):
    # https://github.com/OpenLMLab/MOSS/blob/main/finetune_moss.py
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        data = copy.deepcopy(self.dataset[index]["input_ids"])
        no_loss_spans = copy.deepcopy(self.dataset[index]["no_loss_spans"])

        data = torch.tensor(data, dtype=torch.long)
        label = copy.deepcopy(data)

        # Mask out spans (e.g. the meta instruction) so they do not contribute to the loss.
        for no_loss_span in no_loss_spans:
            label[no_loss_span[0] : no_loss_span[1]] = -100

        return data, label


def collate_fn(batch, tokenizer):
    batch_input_ids, batch_labels = [], []
    for input_ids, label in batch:
        batch_input_ids.append(input_ids)
        batch_labels.append(label)

    # Right-pad inputs with the EOS token and labels with -100 so padding is ignored by the loss.
    batch_input_ids = torch.nn.utils.rnn.pad_sequence(
        batch_input_ids, batch_first=True, padding_value=tokenizer.eos_token_id
    )
    batch_labels = torch.nn.utils.rnn.pad_sequence(batch_labels, batch_first=True, padding_value=-100)

    return {
        "input_ids": batch_input_ids,
        # Attend to non-padding positions. Since EOS doubles as the padding token here,
        # genuine EOS tokens inside the dialogue are masked out as well.
        "attention_mask": (batch_input_ids != tokenizer.eos_token_id).long(),
        "labels": batch_labels,
    }
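
# A minimal sketch of the collated output: for a batch of two samples of token lengths
# 3 and 5, collate_fn pads both to length 5, so the returned dict holds
#   "input_ids":      LongTensor of shape (2, 5), right-padded with tokenizer.eos_token_id
#   "attention_mask": LongTensor of shape (2, 5), 1 wherever the id differs from eos_token_id
#   "labels":         LongTensor of shape (2, 5), right-padded with -100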


def process(sample, tokenizer, max_len):
    # Each turn in the plain-text dialogue is terminated by the <eoa> marker.
    chat = sample["plain_text"].split("<eoa>")[:-1]
    num_turns = sample["num_turns"]
    meta_instruction = sample["prefix"]

    # Encode the meta instruction.
    instruction_ids = tokenizer.encode(meta_instruction)
    assert isinstance(instruction_ids, list), instruction_ids
    assert len(instruction_ids) > 0, len(instruction_ids)
    input_ids = copy.deepcopy(instruction_ids)
    # We do not calculate loss for the instruction.
    no_loss_spans = [(0, len(instruction_ids))]

    for i in range(num_turns):
        # Collect dialogue turns.
        cur_turn_ids = []
        cur_no_loss_spans = []
        # Add the current turn (with its <eoa> marker restored) to cur_turn_ids.
        cur_turn_ids.extend(tokenizer.encode(chat[i] + "<eoa>"))
        # if key == 'Tool Responses':
        #     # The format tokens (<|Results|>:...<eor>\n) should have losses.
        #     cur_no_loss_spans.append(
        #         (len(input_ids + cur_turn_ids) + 5, len(input_ids + cur_turn_ids + cur_ids) - 2)
        #     )
        if len(input_ids + cur_turn_ids) > max_len:
            # Too long, stop adding turns.
            break
        # Extend input_ids.
        input_ids.extend(cur_turn_ids)
        no_loss_spans.extend(cur_no_loss_spans)

    if len(input_ids) == len(instruction_ids):
        # No dialogue fit within max_len; return an empty sample so it can be filtered out later.
        return {"input_ids": [], "no_loss_spans": []}
    else:
        return {"input_ids": input_ids, "no_loss_spans": no_loss_spans}


def load_data(save_dir, tokenizer, max_len, num=-1) -> HFDataset:
    if os.path.exists(save_dir):
        print(f"Loading moss-002-sft from {save_dir}")
    else:
        print("Loading moss-002-sft from datasets")
        moss_sft = load_dataset("fnlp/moss-002-sft-data", split="train")
        moss_sft = moss_sft.map(lambda x: process(x, tokenizer, max_len), num_proc=10)
        # Drop samples whose dialogue was empty or did not fit within max_len.
        moss_sft = moss_sft.filter(lambda x: len(x["input_ids"]) != 0)
        moss_sft.save_to_disk(save_dir)

    moss_sft = HFDataset.load_from_disk(save_dir)
    if num != -1:
        moss_sft = moss_sft.select(range(num))
    print(f"Loaded successfully, total {len(moss_sft)} samples.")

    return moss_sft


def get_dataset(tokenizer, save_dir, max_len=1024, num=-1, test_size=0.1):
    moss_sft_data = load_data(save_dir, tokenizer, max_len, num)
    moss_sft_split = moss_sft_data.train_test_split(test_size=test_size)
    train_dataset = SFTDataset(moss_sft_split["train"])
    val_dataset = SFTDataset(moss_sft_split["test"])

    return train_dataset, val_dataset
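

# A minimal usage sketch under stated assumptions: the tokenizer name and cache directory
# below are placeholders rather than values taken from this repository, and building the
# dataset requires network access to the Hugging Face Hub.
if __name__ == "__main__":
    from functools import partial

    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    # Hypothetical tokenizer choice; any HF tokenizer exposing eos_token_id should work here.
    tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-7b", trust_remote_code=True)

    # "moss_sft_cache" is an assumed local directory for the processed dataset.
    train_dataset, val_dataset = get_dataset(tokenizer, "moss_sft_cache", max_len=1024, num=100)

    loader = DataLoader(
        train_dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
    )
    batch = next(iter(loader))
    print({k: tuple(v.shape) for k, v in batch.items()})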