#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import itertools as it
import operator
import os
from copy import deepcopy
from typing import Dict

import numpy as np
import torch
from torch.utils.data import ConcatDataset
from tqdm import tqdm

from internlm.core.context import global_context as gpc
from internlm.data.single_dataset import JsonlDataset
from internlm.data.utils import get_dataset_type_id
from internlm.utils.logger import get_logger

DEFAULT_SEED = 1024
logger = get_logger(__file__)


class PackedDataset(torch.utils.data.Dataset):
    """
    The class PackedDataset takes in a dataset and aggregates samples of different
    lengths together based on the packed_length.

    Args:
        dataset: The original dataset to pack.
        max_length_per_sample: The maximum length of each original sample. Default is 2048.
        packed_length: The length of each packed sample. Default is 4096.
    """

    def __init__(
        self,
        dataset,
        max_length_per_sample: int = 2048,
        packed_length: int = 4096,
    ):
        assert hasattr(dataset, "lengths")
        assert len(getattr(dataset, "lengths")) == len(
            dataset
        ), "The dataset must have a `lengths` attribute whose length equals the length of the dataset"
        self.dataset = dataset
        self.max_length_per_sample = max_length_per_sample
        self.lengths = getattr(self.dataset, "lengths")
        self.packed_length = packed_length
        # Use a fixed seed so that the shuffle is reproducible even if the RNG
        # state is not restored when training restarts.
        self.seed = DEFAULT_SEED
        self.sample_indices, self.len_samples_shuffled, self.acm_len_samples = self.accu_sample_len(seed=self.seed)
        self.num_tokens = sum(self.lengths)

    def get_dataset_name(self):
        return self.dataset.get_dataset_name()

    def accu_sample_len(self, seed=None):
        """Shuffle the samples and compute the accumulated (prefix-sum) lengths."""
        if seed is not None:
            rng = np.random.RandomState(seed)
        else:
            rng = np.random.RandomState(self.seed - 1)

        sample_indices = np.arange(len(self.lengths))
        rng.shuffle(sample_indices)
        len_samples_shuffled = list(map(self.lengths.__getitem__, sample_indices))
        acm_len_samples = list(it.accumulate(len_samples_shuffled, operator.add))
        return sample_indices, len_samples_shuffled, acm_len_samples
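
    # Worked example for the pack-to-sample index arithmetic implemented by
    # `cal_map` / `mapping` below (illustrative numbers, not taken from any real
    # dataset): if the shuffled sample lengths are [3, 5, 4], then
    # acm_len_samples == [3, 8, 12]. With packed_length == 4, pack 0 covers
    # tokens [0, 4) of the shuffled stream, i.e. all of shuffled sample 0 plus
    # the first token of shuffled sample 1, which `np.searchsorted` over
    # acm_len_samples recovers.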

    def __len__(self):
        # Follows line 405 of metaseq's document_to_sequence.py: documents are
        # spliced together directly, without additional handling of sos or eos tokens.
        n_packs = self.num_tokens // self.packed_length
        return n_packs

    def cal_map(self, carriage_idx: int = 0):
        assert carriage_idx >= 0
        length_train = (carriage_idx + 1) * self.packed_length
        post_pos = np.searchsorted(self.acm_len_samples, length_train, side="left")
        return post_pos

    def mapping(self, pack_idx: int = 0):
        # pack_idx is zero-based
        pre_pos, pre_token_id = 0, 0
        if pack_idx > 0:
            pre_pos = self.cal_map(pack_idx - 1)
            pre_token_id = self.len_samples_shuffled[pre_pos] - (
                self.acm_len_samples[pre_pos] - (pack_idx) * self.packed_length
            )
            if pre_token_id == self.len_samples_shuffled[pre_pos]:
                pre_pos += 1
                pre_token_id = 0

        pos = self.cal_map(pack_idx)
        token_id = self.len_samples_shuffled[pos] - (self.acm_len_samples[pos] - (pack_idx + 1) * self.packed_length)
        return pre_pos, pre_token_id, pos, token_id
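
    # Continuing the example above (shuffled lengths [3, 5, 4], packed_length == 4):
    # mapping(0) returns (pre_pos=0, pre_token_id=0, pos=1, token_id=1), so
    # `build_pack` below copies shuffled sample 0 in full and the first token of
    # shuffled sample 1 into pack 0.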

    def build_pack(self, pre_pos: int, pre_token_id: int, pos: int, token_id: int):
        pack, cu_seqlens, indexes, labels, type_ids = [], [0], [], [], []

        # Copy every sample that is fully consumed by this pack.
        while pre_pos < pos:
            sample_idx = self.sample_indices[pre_pos]
            sample = self.dataset[sample_idx]
            chunk = sample["tokens"][pre_token_id:]
            pack.extend(chunk)
            _labels = deepcopy(chunk)
            _labels = list(_labels[1:]) + [-100]
            assert len(_labels) == len(chunk), (_labels, chunk)
            labels.extend(_labels)
            type_ids.extend([sample.get("type_id", 0)] * len(chunk))
            num_new_samples, tokens_left = divmod(len(chunk), self.max_length_per_sample)
            for _ in range(num_new_samples):
                cu_seqlens.append(cu_seqlens[-1] + self.max_length_per_sample)
                indexes.extend(list(range(self.max_length_per_sample)))
            if tokens_left > 0:
                cu_seqlens.append(cu_seqlens[-1] + tokens_left)
                indexes.extend(list(range(tokens_left)))
            pre_pos = pre_pos + 1
            pre_token_id = 0

        # Handle the last (possibly partial) sample of this pack.
        sample_idx = self.sample_indices[pos]
        sample = self.dataset[sample_idx]
        chunk = sample["tokens"][pre_token_id:token_id]  # fragment of a sample
        pack.extend(chunk)
        _labels = deepcopy(chunk)
        if token_id == len(sample["tokens"]):
            _labels = list(_labels[1:]) + [-100]
        else:
            if token_id > len(sample["tokens"]):
                print(f"token_id {token_id}, len of sample {len(sample['tokens'])}")
            _labels = list(_labels[1:]) + [sample["tokens"][token_id]]
        assert len(_labels) == len(chunk), (_labels, chunk)
        labels.extend(_labels)
        type_ids.extend([sample.get("type_id", 0)] * len(chunk))
        num_new_samples, tokens_left = divmod(len(chunk), self.max_length_per_sample)
        for _ in range(num_new_samples):
            cu_seqlens.append(cu_seqlens[-1] + self.max_length_per_sample)
            indexes.extend(list(range(self.max_length_per_sample)))
        if tokens_left > 0:
            cu_seqlens.append(cu_seqlens[-1] + tokens_left)
            indexes.extend(list(range(tokens_left)))

        out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
        return out

    def __getitem__(self, item: int) -> Dict:
        """Given the index, it returns a dict as
        {
            'tokens': List[int],
            'cu_seqlens': List[int],
            'indexes': List[int],  # position indices aligned with 'tokens'
            'labels': List[int],  # 'tokens' shifted left by one; -100 marks positions whose prediction is skipped
            'type_ids': List[int],  # dataset type id for each token
        }
        """
        pos_before, token_id_before, pos_after, token_id_after = self.mapping(item)
        return self.build_pack(pos_before, token_id_before, pos_after, token_id_after)
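

# Illustrative sketch (not part of the original API): sanity-check the invariants
# of the packs produced by PackedDataset. The assertions mirror how `build_pack`
# constructs its output; `packed_dataset` is any PackedDataset instance, e.g. one
# built as in `get_packed_dataset_without_short_length` below.
def _check_packed_dataset_invariants(packed_dataset: PackedDataset, num_packs: int = 2):
    for idx in range(min(num_packs, len(packed_dataset))):
        pack = packed_dataset[idx]
        # Every pack holds exactly `packed_length` tokens, labels and indexes align
        # with tokens, and `cu_seqlens` is a prefix sum ending at `packed_length`.
        assert len(pack["tokens"]) == packed_dataset.packed_length
        assert len(pack["labels"]) == len(pack["tokens"])
        assert len(pack["indexes"]) == len(pack["tokens"])
        assert pack["cu_seqlens"][0] == 0
        assert pack["cu_seqlens"][-1] == packed_dataset.packed_length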


class PackedDatasetWithoutCuSeqlen(torch.utils.data.Dataset):
    """
    A dataset wrapper that aggregates samples with different lengths based on packed_length.
    If a sample is shorter than max_length_per_sample, it will be merged with other samples.
    For example, given a dataset with 10 samples:
        [1, 2, 3, 4, 5]
        [6, 7]
        [8, 9, 10, 11]
        [12, ..., 100]
        ...

    Args:
        dataset: The original dataset to be wrapped.
        max_length_per_sample (int): The maximum length allowed for each sample.
        packed_length (int): The desired length for each packed sample.
    """

    def __init__(
        self,
        dataset,
        max_length_per_sample: int = 2048,
        packed_length: int = 4096,
        debug=False,
    ):
        assert packed_length % max_length_per_sample == 0
        assert hasattr(dataset, "lengths")
        assert len(getattr(dataset, "lengths")) == len(
            dataset
        ), "The dataset must have a `lengths` attribute whose length equals the length of the dataset"
        self.dataset = dataset
        self.max_length_per_sample = max_length_per_sample
        self.lengths = getattr(self.dataset, "lengths")
        self.bsz = packed_length // max_length_per_sample
        self.packed_length = packed_length
        self.debug = debug
        # Use a fixed seed so that the shuffle is reproducible even if the RNG
        # state is not restored when training restarts.
        self.seed = DEFAULT_SEED
        indices = np.arange(len(self.lengths))
        rng = np.random.RandomState(self.seed)
        rng.shuffle(indices)
        self.indices = indices
        # Convert to an array before fancy indexing so this also works when
        # `lengths` is a plain Python list rather than a numpy array.
        self.cum_lens = np.cumsum(np.asarray(self.lengths)[self.indices])
        self.num_tokens = sum(self.lengths)

    def get_dataset_name(self):
        return self.dataset.get_dataset_name()

    def __len__(self):
        n_packs = self.num_tokens // self.packed_length
        return n_packs

    def find_offset(self, offset):
        idx = np.searchsorted(self.cum_lens, offset, side="right")
        if idx == 0:
            return idx, offset
        length = offset - self.cum_lens[idx - 1]
        return idx, length
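
    # Worked example for `find_offset` (illustrative numbers): with
    # cum_lens == [3, 8, 12], find_offset(0) -> (0, 0), find_offset(3) -> (1, 0)
    # and find_offset(5) -> (1, 2), i.e. the token at global offset 5 is token 2
    # of the shuffled sample at index 1.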

    def pdebug(self, line):
        if self.debug:
            print(line, flush=True)

    def __getitem__(self, item: int) -> Dict:
        """Given the index, it returns a dict as
        {
            'tokens': List[int],
            'cu_seqlens': List[int],
            'indexes': List[int],  # position indices aligned with 'tokens'
            'labels': List[int],  # 'tokens' shifted left by one; -100 marks positions whose prediction is skipped
            'type_ids': List[int],  # dataset type id for each token
        }
        """
        start_idx, start_length = self.find_offset(item * self.packed_length)
        end_idx, end_length = self.find_offset((item + 1) * self.packed_length)
        pack_tokens = []
        pack_labels = []
        type_ids = []

        self.pdebug(f"item : {item}, start_idx:{start_idx}, start_length:{start_length} ")
        self.pdebug(f"item : {item}, end_idx:{end_idx}, end_length:{end_length} ")

        if start_idx == end_idx:
            # The whole pack comes from a single (long) sample.
            idx = self.indices[start_idx]
            sample = self.dataset[idx]
            self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
            tokens = sample["tokens"][start_length:end_length]
            pack_tokens.extend(tokens)
            pack_labels.extend(tokens[1:] + [-100])
            # Default to type id 0 when absent, matching PackedDataset.
            type_ids.extend([sample.get("type_id", 0)] * len(tokens))
            return {
                "tokens": pack_tokens,
                "cu_seqlens": [i * self.max_length_per_sample for i in range(self.bsz + 1)],
                "indexes": list(range(self.max_length_per_sample)) * self.bsz,
                "labels": pack_labels,
                "type_ids": type_ids,
            }

        # First (possibly partial) sample of the pack.
        idx = self.indices[start_idx]
        sample = self.dataset[idx]
        self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
        tokens = sample["tokens"][start_length:]
        pack_tokens.extend(tokens)
        pack_labels.extend(tokens[1:] + [-100])
        type_ids.extend([sample.get("type_id", 0)] * len(tokens))

        # Samples fully contained in the pack.
        for i in range(start_idx + 1, end_idx):
            idx = self.indices[i]
            sample = self.dataset[idx]
            self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
            tokens = sample["tokens"]
            pack_tokens.extend(tokens)
            pack_labels.extend(tokens[1:] + [-100])
            type_ids.extend([sample.get("type_id", 0)] * len(tokens))

        # Corner case: the pack boundary coincides with the end of the previous
        # sample, so the last sample contributes nothing to this pack.
        if end_length == 0:
            pass
        else:
            idx = self.indices[end_idx]
            sample = self.dataset[idx]
            self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
            tokens = sample["tokens"][:end_length]
            pack_tokens.extend(tokens)
            pack_labels.extend(tokens[1:] + [-100])
            type_ids.extend([sample.get("type_id", 0)] * len(tokens))

        return {
            "tokens": pack_tokens,
            "cu_seqlens": [i * self.max_length_per_sample for i in range(self.bsz + 1)],
            "indexes": list(range(self.max_length_per_sample)) * self.bsz,
            "labels": pack_labels,
            "type_ids": type_ids,
        }
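

# Illustrative sketch (not part of the original API): the key difference from
# PackedDataset is that here `cu_seqlens` is a fixed grid of
# `bsz = packed_length // max_length_per_sample` equal windows, regardless of
# where the underlying samples actually begin and end.
def _check_packed_without_cu_seqlen_invariants(packed_dataset: PackedDatasetWithoutCuSeqlen, num_packs: int = 2):
    expected_cu_seqlens = [i * packed_dataset.max_length_per_sample for i in range(packed_dataset.bsz + 1)]
    for idx in range(min(num_packs, len(packed_dataset))):
        pack = packed_dataset[idx]
        assert pack["cu_seqlens"] == expected_cu_seqlens
        assert len(pack["tokens"]) == packed_dataset.packed_length
        assert len(pack["labels"]) == len(pack["tokens"])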


def get_packed_dataset_without_short_length(
    folder,
    max_length_per_sample=2048,
    packed_length=4096,
    show_progress=False,
    min_length=50,
    min_length_dict=None,
    pack_into_one_sample=False,
):
    """
    Given a folder, combine all the .bin files into a single large dataset
    and filter out samples shorter than 'min_length'.

    Each .bin file is treated as a separate dataset.

    Args:
        folder (str): Path to the folder containing the .bin files.
        max_length_per_sample (int): Maximum length of each sample.
        packed_length (int): Length to pack samples to.
        show_progress (bool): Whether to show the progress bar.
        min_length (int): The minimum length of the sample.
        min_length_dict (dict): The minimum length of the sample for each dataset.
            The format is something like {'pile-arxiv': 50}
        pack_into_one_sample (bool): If True, wrap each dataset with
            PackedDatasetWithoutCuSeqlen instead of PackedDataset.

    Returns:
        A packed dataset containing all the data from the .bin files.
    """

    assert os.path.exists(folder), f"{folder} does not exist."
    datasets = []
    delete_samples = 0

    for root, dirs, files in os.walk(folder, followlinks=True):
        dirs.sort()  # Sort directories so that folders are traversed in a deterministic order
        if gpc.is_rank_for_log():
            logger.info(f"Reading {root}...")
        num_token_in_folder = 0

        for fn in tqdm(sorted(files), total=len(files), leave=False, disable=not show_progress):
            if fn.endswith(".bin"):
                fp = os.path.join(root, fn)
                catch_ml_keys = []
                min_length_num = min_length
                if min_length_dict is not None:
                    for k, v in min_length_dict.items():
                        if k in fp:
                            min_length_num = v
                            catch_ml_keys.append(k)
                    assert (
                        len(catch_ml_keys) < 2
                    ), f"The file name `{fp}` matched the following resample keys:{catch_ml_keys}"

                ds_type_id = get_dataset_type_id(path=fp)
                ds = JsonlDataset(fp, ds_type_id, min_length=min_length_num)

                if hasattr(ds, "old_length"):
                    delete_samples += ds.old_length - len(ds)
                if len(ds) == 0:
                    if gpc.is_rank_for_log():
                        logger.info(f"None of the data in `{fp}` is longer than {min_length_num}")
                    continue

                if pack_into_one_sample:
                    ds = PackedDatasetWithoutCuSeqlen(ds, max_length_per_sample, packed_length)
                else:
                    ds = PackedDataset(ds, max_length_per_sample, packed_length)

                num_token_in_folder += len(ds) * packed_length
                datasets.append(ds)

    dataset = ConcatDataset(datasets=datasets)
    if gpc.is_rank_for_log():
        logger.info(
            f"Found `{len(datasets)}` datasets with {len(dataset)} samples in total; "
            f"deleted `{delete_samples}` samples because of short length"
        )

    return dataset
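

# A minimal builder sketch (illustrative, not part of the original API). It assumes
# the InternLM global context (`gpc`) has already been initialized by the training
# entry point; the folder path and min_length_dict key below are placeholders.
def _example_build_train_dataset(folder: str = "/path/to/tokenized/bins"):
    return get_packed_dataset_without_short_length(
        folder=folder,
        max_length_per_sample=2048,
        packed_length=4096,
        show_progress=True,
        min_length=50,
        min_length_dict={"pile-arxiv": 100},  # hypothetical per-dataset override
        pack_into_one_sample=False,
    )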