InternLM/internlm/data/packed_dataset.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import itertools as it
import operator
import os
from typing import Dict

import numpy as np
import torch
from torch.utils.data import ConcatDataset
from tqdm import tqdm

from internlm.core.context import global_context as gpc
from internlm.data.single_dataset import JsonlDataset
from internlm.data.utils import get_dataset_type_id
from internlm.utils.logger import get_logger

DEFAULT_SEED = 1024
logger = get_logger(__file__)


class PackedDataset(torch.utils.data.Dataset):
"""
The class PackedDataset takes in a dataset and aggregates samples of different
lengths together based on the packed_length.

    Args:
dataset: The original dataset to pack.
max_length_per_sample: The maximum length of each original sample. Default is 2048.
packed_length: The length of each packed sample. Default is 4096.
"""

    def __init__(
self,
dataset,
max_length_per_sample: int = 2048,
packed_length: int = 4096,
):
assert hasattr(dataset, "lengths")
        assert len(getattr(dataset, "lengths")) == len(
            dataset
        ), "The dataset must have a `lengths` attribute with one entry per sample"
self.dataset = dataset
self.max_length_per_sample = max_length_per_sample
self.lengths = getattr(self.dataset, "lengths")
self.packed_length = packed_length
        # Fix the seed so the shuffle is reproducible when training restarts.
self.seed = DEFAULT_SEED
self.sample_indices, self.len_samples_shuffled, self.acm_len_samples = self.accu_sample_len(seed=self.seed)
self.num_tokens = sum(self.lengths)

    def get_dataset_name(self):
return self.dataset.get_dataset_name()

    def accu_sample_len(self, seed=None):
        """Shuffle the samples and return their shuffled indices, shuffled
        lengths, and the accumulated (cumulative) lengths."""
if seed is not None:
rng = np.random.RandomState(seed)
else:
rng = np.random.RandomState(self.seed - 1)
sample_indices = np.arange(len(self.lengths))
rng.shuffle(sample_indices)
len_samples_shuffled = list(map(self.lengths.__getitem__, sample_indices))
acm_len_samples = list(it.accumulate(len_samples_shuffled, operator.add))
return sample_indices, len_samples_shuffled, acm_len_samples
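
    # A worked example (illustrative values, not from the source): with
    # lengths [3, 5, 2] shuffled to sample_indices [1, 0, 2], the shuffled
    # lengths become [5, 3, 2] and acm_len_samples is [5, 8, 10]; pack
    # boundaries can then be located by binary search over acm_len_samples.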

    def __len__(self):
        # Matches line 405 of document_to_sequence.py in metaseq: tokens are
        # spliced directly, with no extra accounting for sos or eos tokens.
n_packs = self.num_tokens // self.packed_length
return n_packs

    def cal_map(self, carriage_idx: int = 0):
        # Return the index (in shuffled order) of the sample containing the
        # last token of pack `carriage_idx`.
        assert carriage_idx >= 0
length_train = (carriage_idx + 1) * self.packed_length
post_pos = np.searchsorted(self.acm_len_samples, length_train, side="left")
return post_pos
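
    # Continuing the illustrative example above: with acm_len_samples =
    # [5, 8, 10] and packed_length = 8, cal_map(0) searches for 8 with
    # side="left" and returns 1, i.e. the first pack ends inside shuffled
    # sample 1.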

    def mapping(self, pack_idx: int = 0):
        # pack_idx is zero-based; locate where pack `pack_idx` starts
        # (pre_pos, pre_token_id) and where it ends (pos, token_id).
        pre_pos, pre_token_id = 0, 0
if pack_idx > 0:
pre_pos = self.cal_map(pack_idx - 1)
            pre_token_id = self.len_samples_shuffled[pre_pos] - (
                self.acm_len_samples[pre_pos] - pack_idx * self.packed_length
            )
if pre_token_id == self.len_samples_shuffled[pre_pos]:
pre_pos += 1
pre_token_id = 0
pos = self.cal_map(pack_idx)
token_id = self.len_samples_shuffled[pos] - (self.acm_len_samples[pos] - (pack_idx + 1) * self.packed_length)
return pre_pos, pre_token_id, pos, token_id
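
    # Continuing the illustrative example: with shuffled lengths [5, 3, 2]
    # (acm_len_samples = [5, 8, 10]) and packed_length = 8, mapping(0) returns
    # (0, 0, 1, 3): pack 0 spans all 5 tokens of shuffled sample 0 plus the
    # first 3 tokens of shuffled sample 1.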

    def build_pack(self, pre_pos: int, pre_token_id: int, pos: int, token_id: int):
pack, cu_seqlens, indexes, labels, type_ids = [], [0], [], [], []
while pre_pos < pos:
sample_idx = self.sample_indices[pre_pos]
sample = self.dataset[sample_idx]
chunk = sample["tokens"][pre_token_id:]
pack.extend(chunk)
            # Labels are the tokens shifted left by one; -100 masks the final position.
            _labels = list(chunk[1:]) + [-100]
assert len(_labels) == len(chunk), (_labels, chunk)
labels.extend(_labels)
type_ids.extend([sample.get("type_id", 0)] * len(chunk))
            # Record attention segment boundaries: long samples are split into
            # chunks of at most max_length_per_sample tokens.
            num_new_samples, tokens_left = divmod(len(chunk), self.max_length_per_sample)
for _ in range(num_new_samples):
cu_seqlens.append(cu_seqlens[-1] + self.max_length_per_sample)
indexes.extend(list(range(self.max_length_per_sample)))
if tokens_left > 0:
cu_seqlens.append(cu_seqlens[-1] + tokens_left)
indexes.extend(list(range(tokens_left)))
pre_pos = pre_pos + 1
pre_token_id = 0
sample_idx = self.sample_indices[pos]
sample = self.dataset[sample_idx]
chunk = sample["tokens"][pre_token_id:token_id] # fragement of a sample
pack.extend(chunk)
        if token_id == len(sample["tokens"]):
            # The fragment ends exactly at the sample boundary, so mask the last label.
            _labels = list(chunk[1:]) + [-100]
        else:
            if token_id > len(sample["tokens"]):
                logger.warning(f"token_id {token_id}, len of sample {len(sample['tokens'])}")
            # The label of the last position is the next token of the original sample.
            _labels = list(chunk[1:]) + [sample["tokens"][token_id]]
assert len(_labels) == len(chunk), (_labels, chunk)
labels.extend(_labels)
type_ids.extend([sample.get("type_id", 0)] * len(chunk))
num_new_samples, tokens_left = divmod(len(chunk), self.max_length_per_sample)
for _ in range(num_new_samples):
cu_seqlens.append(cu_seqlens[-1] + self.max_length_per_sample)
indexes.extend(list(range(self.max_length_per_sample)))
if tokens_left > 0:
cu_seqlens.append(cu_seqlens[-1] + tokens_left)
indexes.extend(list(range(tokens_left)))
out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
return out
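
    # For the example above, build_pack(0, 0, 1, 3) with max_length_per_sample
    # = 4 would return 8 tokens with cu_seqlens = [0, 4, 5, 8]: the 5-token
    # sample is split into segments of 4 and 1 tokens, and the 3-token
    # fragment forms the final segment.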

    def __getitem__(self, item: int) -> Dict:
        """Given the index, return a dict like
        {
            'tokens': List[int],
            'cu_seqlens': List[int],  # cumulative sequence lengths marking segment boundaries
            'indexes': List[int],  # position index of each token within its segment
            'labels': List[int],  # 'tokens' shifted left by one; -100 means the position is excluded from the loss
            'type_ids': List[int],  # dataset type id of each token
        }
        """
pos_before, token_id_before, pos_after, token_id_after = self.mapping(item)
return self.build_pack(pos_before, token_id_before, pos_after, token_id_after)


class PackedDatasetWithoutCuSeqlen(torch.utils.data.Dataset):
    """
    A dataset wrapper that aggregates samples of different lengths based on
    packed_length, without keeping track of the original sample boundaries.
    If a sample is shorter than max_length_per_sample, it will be merged with
    other samples. For example, given samples with the token sequences
        [1, 2, 3, 4, 5]
        [6, 7]
        [8, 9, 10, 11]
        [12, ..., 100]
        ...
    they are concatenated and re-split into fixed chunks of
    max_length_per_sample tokens.

    Args:
        dataset: The original dataset to be wrapped.
        max_length_per_sample (int): The maximum length allowed for each sample.
        packed_length (int): The desired length for each packed sample.
    """

    def __init__(
self,
dataset,
max_length_per_sample: int = 2048,
packed_length: int = 4096,
debug=False,
):
assert packed_length % max_length_per_sample == 0
assert hasattr(dataset, "lengths")
assert len(getattr(dataset, "lengths")) == len(
dataset
), "The dataset must have lengths attribute and have the same length as the dataset"
self.dataset = dataset
self.max_length_per_sample = max_length_per_sample
self.lengths = getattr(self.dataset, "lengths")
self.bsz = packed_length // max_length_per_sample
self.packed_length = packed_length
self.debug = debug
        # Fix the seed so the shuffle is reproducible when training restarts.
self.seed = DEFAULT_SEED
indices = np.arange(len(self.lengths))
rng = np.random.RandomState(self.seed)
rng.shuffle(indices)
self.indices = indices
self.cum_lens = np.cumsum(self.lengths[self.indices])
self.num_tokens = sum(self.lengths)

    def get_dataset_name(self):
return self.dataset.get_dataset_name()

    def __len__(self):
n_packs = self.num_tokens // self.packed_length
return n_packs

    def find_offset(self, offset):
        # Map a global token offset to (shuffled sample index, offset within that sample).
        idx = np.searchsorted(self.cum_lens, offset, side="right")
if idx == 0:
return idx, offset
length = offset - self.cum_lens[idx - 1]
return idx, length
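
    # An illustrative example: with cum_lens = [5, 12, 15], find_offset(8)
    # returns (1, 3), i.e. global token offset 8 falls 3 tokens into the
    # second shuffled sample.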

    def pdebug(self, line):
if self.debug:
print(line, flush=True)

    def __getitem__(self, item: int) -> Dict:
        """Given the index, return a dict like
        {
            'tokens': List[int],
            'cu_seqlens': List[int],  # fixed, evenly spaced segment boundaries
            'indexes': List[int],  # position index of each token within its segment
            'labels': List[int],  # 'tokens' shifted left by one; -100 means the position is excluded from the loss
            'type_ids': List[int],  # dataset type id of each token
        }
        """
start_idx, start_length = self.find_offset(item * self.packed_length)
end_idx, end_length = self.find_offset((item + 1) * self.packed_length)
pack_tokens = []
pack_labels = []
type_ids = []
self.pdebug(f"item : {item}, start_idx:{start_idx}, start_length:{start_length} ")
self.pdebug(f"item : {item}, end_idx:{end_idx}, end_length:{end_length} ")
if start_idx == end_idx:
idx = self.indices[start_idx]
sample = self.dataset[idx]
self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
tokens = sample["tokens"][start_length:end_length]
pack_tokens.extend(tokens)
pack_labels.extend(tokens[1:] + [-100])
type_ids.extend([sample["type_id"]] * len(tokens))
return {
"tokens": pack_tokens,
"cu_seqlens": [i * self.max_length_per_sample for i in range(self.bsz + 1)],
"indexes": list(range(self.max_length_per_sample)) * self.bsz,
"labels": pack_labels,
"type_ids": type_ids,
}
idx = self.indices[start_idx]
sample = self.dataset[idx]
self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
tokens = sample["tokens"][start_length:]
pack_tokens.extend(tokens)
pack_labels.extend(tokens[1:] + [-100])
type_ids.extend([sample["type_id"]] * len(tokens))
for i in range(start_idx + 1, end_idx):
idx = self.indices[i]
sample = self.dataset[idx]
self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
tokens = sample["tokens"]
pack_tokens.extend(tokens)
pack_labels.extend(tokens[1:] + [-100])
type_ids.extend([sample.get("type_id")] * len(tokens))
        # Corner case: the pack ends exactly on a sample boundary, so nothing
        # is taken from sample end_idx.
        if end_length > 0:
            idx = self.indices[end_idx]
            sample = self.dataset[idx]
            self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
            tokens = sample["tokens"][:end_length]
            pack_tokens.extend(tokens)
            pack_labels.extend(tokens[1:] + [-100])
            type_ids.extend([sample.get("type_id", 0)] * len(tokens))
return {
"tokens": pack_tokens,
"cu_seqlens": [i * self.max_length_per_sample for i in range(self.bsz + 1)],
"indexes": list(range(self.max_length_per_sample)) * self.bsz,
"labels": pack_labels,
"type_ids": type_ids,
}
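
    # Note that, unlike PackedDataset, the cu_seqlens returned here ignore the
    # true sample boundaries: each pack is simply cut into `bsz` equal segments
    # of max_length_per_sample tokens, e.g. cu_seqlens = [0, 2048, 4096] when
    # bsz == 2 and max_length_per_sample == 2048.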


def get_packed_dataset_without_short_length(
folder,
max_length_per_sample=2048,
packed_length=4096,
show_progress=False,
min_length=50,
min_length_dict=None,
pack_into_one_sample=False,
):
"""
Given a folder, combine all the .bin files into a single large dataset.
And filter out short samples with length less than 'min_length'.
Each .bin file is treated as a separate dataset.
Args:
folder (str): Path to the folder containing the .bin files.
max_length_per_sample (int): Maximum length of each sample.
packed_length (int): Length to pack samples to.
show_progress (bool): Whether to show the progress bar.
min_length (int): The minimum length of the sample.
min_length_dict (dict): The minimum length of the sample for each dataset.
The format is something like {'pile-arxiv': 50}
dataset_backend (Optional[str]): Dataset storage location. Optional parameters are local, local-shm, kv
Returns:
A packed dataset containing all the data from the .bin files.
"""
assert os.path.exists(folder), f"{folder} does not exist."
datasets = []
delete_samples = 0
for root, dirs, files in os.walk(folder, followlinks=True):
        dirs.sort()  # sort dirs so folders are traversed in a deterministic order
if gpc.is_rank_for_log():
logger.info(f"Reading {root}...")
num_token_in_folder = 0
for fn in tqdm(sorted(files), total=len(files), leave=False, disable=not show_progress):
if fn.endswith(".bin"):
fp = os.path.join(root, fn)
catch_ml_keys = []
min_length_num = min_length
if min_length_dict is not None:
for k, v in min_length_dict.items():
if k in fp:
min_length_num = v
catch_ml_keys.append(k)
                assert (
                    len(catch_ml_keys) < 2
                ), f"The file name `{fp}` matched more than one min_length_dict key: {catch_ml_keys}"
ds_type_id = get_dataset_type_id(path=fp)
ds = JsonlDataset(fp, ds_type_id, min_length=min_length_num)
if hasattr(ds, "old_length"):
delete_samples += ds.old_length - len(ds)
if len(ds) == 0:
if gpc.is_rank_for_log():
logger.info(f"None of the data in `{fp}` is longer than {min_length}")
continue
if pack_into_one_sample:
ds = PackedDatasetWithoutCuSeqlen(ds, max_length_per_sample, packed_length)
else:
ds = PackedDataset(ds, max_length_per_sample, packed_length)
num_token_in_folder += len(ds) * packed_length
datasets.append(ds)

    dataset = ConcatDataset(datasets=datasets)
if gpc.is_rank_for_log():
        logger.info(
            f"Found `{len(datasets)}` datasets with {len(dataset)} samples in total; "
            f"deleted `{delete_samples}` samples because they were shorter than the minimum length"
        )
return dataset
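

if __name__ == "__main__":
    # A minimal smoke test (an illustrative sketch, assuming the internlm
    # package is importable; `_DummyDataset` is a stand-in and not part of
    # InternLM -- any dataset exposing `lengths`, `__len__`, `__getitem__`
    # and `get_dataset_name` would work).
    class _DummyDataset(torch.utils.data.Dataset):
        def __init__(self, lengths):
            self.lengths = lengths

        def __len__(self):
            return len(self.lengths)

        def __getitem__(self, idx):
            # Mirror the sample layout produced by JsonlDataset.
            return {"tokens": list(range(self.lengths[idx])), "type_id": 0}

        def get_dataset_name(self):
            return "dummy"

    ds = PackedDataset(_DummyDataset([5, 3, 2]), max_length_per_sample=4, packed_length=8)
    print(len(ds))  # 10 tokens // 8 -> 1 pack
    print(ds[0]["cu_seqlens"])  # segment boundaries within the pack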