feat(data): walk folder to get dataset_type_ids_map (#477)

* walk folder to get dataset_type_ids_map

* fix a bug
pull/479/head
Yang Gao 2023-11-07 19:21:10 +08:00 committed by GitHub
parent 4d1b1cd5f1
commit 6f69bd2087
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 11 deletions

View File

@ -15,7 +15,7 @@ from tqdm import tqdm
from internlm.core.context import global_context as gpc from internlm.core.context import global_context as gpc
from internlm.data.single_dataset import JsonlDataset from internlm.data.single_dataset import JsonlDataset
from internlm.data.utils import get_dataset_type_id from internlm.data.utils import get_dataset_type_id, get_dataset_type_ids_map
from internlm.utils.logger import get_logger from internlm.utils.logger import get_logger
DEFAULT_SEED = 1024 DEFAULT_SEED = 1024
@ -373,6 +373,8 @@ def get_packed_dataset_without_short_length(
datasets = [] datasets = []
delete_samples = 0 delete_samples = 0
DATASET_TYPE_IDS_MAP = get_dataset_type_ids_map(folder)
if gpc.get_global_rank() == 0: if gpc.get_global_rank() == 0:
triples = [list(os.walk(folder, followlinks=True))] triples = [list(os.walk(folder, followlinks=True))]
else: else:
@ -400,7 +402,7 @@ def get_packed_dataset_without_short_length(
len(catch_ml_keys) < 2 len(catch_ml_keys) < 2
), f"The file name `{fp}` matched the following resample keys:{catch_ml_keys}" ), f"The file name `{fp}` matched the following resample keys:{catch_ml_keys}"
ds_type_id = get_dataset_type_id(path=fp) ds_type_id = get_dataset_type_id(DATASET_TYPE_IDS_MAP, path=fp)
ds = JsonlDataset(fp, ds_type_id, min_length=min_length_num) ds = JsonlDataset(fp, ds_type_id, min_length=min_length_num)
if hasattr(ds, "old_length"): if hasattr(ds, "old_length"):

View File

@ -1,21 +1,26 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
import os
import re
import torch import torch
from internlm.core.context import global_context as gpc from internlm.core.context import global_context as gpc
DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1, "code": 2}
def get_dataset_type_ids_map(path):
    """Build a mapping from dataset sub-folder name to a stable integer type id.

    The entries are sorted alphabetically so the assigned ids are
    deterministic across runs (and across ranks reading the same folder).

    Args:
        path (str): root dataset folder whose immediate sub-directories name
            the dataset types (e.g. ``en``, ``cn``, ``code``).

    Returns:
        dict[str, int]: sub-folder name -> 0-based type id in sorted order.
    """
    # Keep directories only: a stray regular file (README, checksum, ...) in
    # the dataset root would otherwise receive a type id of its own and shift
    # the ids of every directory sorting after it.
    dirlist = sorted(
        entry
        for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))
    )
    return {key: idx for idx, key in enumerate(dirlist)}
def get_dataset_type_id(dataset_type_ids_map, path):
    """Resolve the dataset type id for a file path.

    A type ``key`` matches when the path contains a ``/{key}/`` component,
    optionally prefixed with ``z`` or ``_`` characters (e.g. ``/z_en/``).
    Exactly one key must match, otherwise an AssertionError is raised.

    Args:
        dataset_type_ids_map (dict[str, int]): type name -> type id.
        path (str): file path to classify.

    Returns:
        int: the single matching type id.
    """
    match_idxes = [
        type_id
        for key, type_id in dataset_type_ids_map.items()
        if re.search(rf"/[z_]*{key}/", path)
    ]
    assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {dataset_type_ids_map}"
    return match_idxes[0]

View File

@ -33,7 +33,7 @@ from internlm.data.packed_dataset import (
PackedDatasetWithoutCuSeqlen, PackedDatasetWithoutCuSeqlen,
get_packed_dataset_without_short_length, get_packed_dataset_without_short_length,
) )
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data from internlm.data.utils import get_dataset_type_ids_map, unpack_data
from internlm.model.embedding import Embedding1D from internlm.model.embedding import Embedding1D
from internlm.model.linear import ( from internlm.model.linear import (
FeedForward, FeedForward,
@ -133,7 +133,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
# wrap the model # wrap the model
grp = gpc.get_group(ParallelMode.ZERO1) grp = gpc.get_group(ParallelMode.ZERO1)
model = FSDP( model = FSDP( # pylint: disable=unexpected-keyword-arg
module=model, module=model,
process_group=grp, process_group=grp,
sharding_strategy=ShardingStrategy.FULL_SHARD, sharding_strategy=ShardingStrategy.FULL_SHARD,
@ -219,11 +219,11 @@ def get_train_data_loader(
# Get the dataset types # Get the dataset types
dataset_types = None dataset_types = None
dataset_types = list(DATASET_TYPE_IDS_MAP.keys())
data_cfg = gpc.config.data data_cfg = gpc.config.data
# Get the sample weight dictionary # Get the sample weight dictionary
train_folder = data_cfg.train_folder train_folder = data_cfg.train_folder
dataset_types = list(get_dataset_type_ids_map(train_folder).keys())
if not train_folder: if not train_folder:
train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len) train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)