mirror of https://github.com/InternLM/InternLM

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
import re

import torch

from internlm.core.context import global_context as gpc


def get_dataset_type_ids_map(path):
    """Map each dataset subdirectory name under ``path`` to an integer id, in sorted order."""
    dirlist = list(os.listdir(path))
    dirlist.sort()
    return {key: idx for idx, key in enumerate(dirlist)}
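
# Usage sketch (illustrative; the directory layout below is hypothetical, not
# from the repo). Ids follow the sorted order of the subdirectory names:
#
#     get_dataset_type_ids_map("/data/corpora")  # contains code/, en/, zh/
#     # -> {"code": 0, "en": 1, "zh": 2}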


def get_dataset_type_id(dataset_type_ids_map, path):
    """Return the unique dataset type id whose key matches a directory component of ``path``."""
    match_idxes = []
    for key, idx in dataset_type_ids_map.items():
        if re.search(rf"/[z_]*{key}/", path):
            match_idxes.append(idx)
    assert len(match_idxes) == 1, f"{path}, len(match_idxes) should be 1, but got {match_idxes} from {dataset_type_ids_map}"
    return match_idxes[0]
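
# Usage sketch (hedged): the regex matches the key as a full path component,
# optionally prefixed by "z" or "_" characters (e.g. "/z_en/"). Paths are
# hypothetical:
#
#     ids_map = {"code": 0, "en": 1, "zh": 2}
#     get_dataset_type_id(ids_map, "/data/corpora/en/train.bin")    # -> 1
#     get_dataset_type_id(ids_map, "/data/corpora/z_zh/train.bin")  # -> 2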


def unpack_data(input_ids, cu_seqlens, is_type_ids: bool = False):
    """
    input_ids: if input_ids is not type_ids, the shape is (1, packed_length);
        otherwise the shape is (micro_num, packed_length).
    cu_seqlens: cumulative sequence lengths delimiting the packed sequences
        in each row of input_ids.
    is_type_ids: whether input_ids holds type_ids.

    Return:
    output: if input_ids is not type_ids, the shape is (micro_bsz, max_length);
        otherwise the shape is (micro_num, micro_bsz, max_length).
    """
    bsz = input_ids.shape[0]

    num_sequence = gpc.config.data["micro_bsz"]

    outputs = torch.zeros(bsz, num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)

    for i in range(bsz):
        output = torch.zeros(num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
        cu_seqlens_slice = cu_seqlens[i]
        for j in range(num_sequence):
            seq_length = cu_seqlens_slice[j + 1] - cu_seqlens_slice[j]
            # read from row i (not row 0) so each batch row is unpacked from its own packed data
            output[j, 0:seq_length] = input_ids[i, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]
        outputs[i] = output

    # If input_ids is not type_ids, squeeze the first dimension when it is 1.
    if bsz == 1 and not is_type_ids:
        outputs = outputs.squeeze(0)

    return outputs
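
# Usage sketch (hedged): toy values, assuming gpc.config.data.micro_bsz == 2 and
# gpc.config.data.seq_len == 4; the tensors below are illustrative only.
#
#     input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7]])  # (1, packed_length=7)
#     cu_seqlens = torch.tensor([[0, 3, 7]])             # two packed sequences, lengths 3 and 4
#     unpack_data(input_ids, cu_seqlens)
#     # -> tensor([[1, 2, 3, 0],
#     #            [4, 5, 6, 7]])  # (micro_bsz=2, seq_len=4), zero-padded per sequence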