mirror of https://github.com/InternLM/InternLM
feat(data): walk folder to get dataset_type_ids_map (#477)
* walk folder to get dataset_type_ids_map
* fix a bug
parent 4d1b1cd5f1
commit 6f69bd2087
internlm/data/packed_dataset.py

@@ -15,7 +15,7 @@ from tqdm import tqdm
 from internlm.core.context import global_context as gpc
 from internlm.data.single_dataset import JsonlDataset
-from internlm.data.utils import get_dataset_type_id
+from internlm.data.utils import get_dataset_type_id, get_dataset_type_ids_map
 from internlm.utils.logger import get_logger
 
 DEFAULT_SEED = 1024
@@ -373,6 +373,8 @@ def get_packed_dataset_without_short_length(
     datasets = []
     delete_samples = 0
 
+    DATASET_TYPE_IDS_MAP = get_dataset_type_ids_map(folder)
+
     if gpc.get_global_rank() == 0:
         triples = [list(os.walk(folder, followlinks=True))]
     else:
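For context, the type-id map is now rebuilt from the dataset folder each time the packed dataset is constructed, and rank 0 collects (dirpath, dirnames, filenames) triples from os.walk(folder, followlinks=True) for the file scan. A minimal sketch of what that walk yields, with an assumed folder layout:

    import os

    # assumed layout: train/cn/a.bin, train/en/b.bin
    for dirpath, dirnames, filenames in os.walk("train", followlinks=True):
        # followlinks=True also descends into symlinked dataset directories
        print(dirpath, dirnames, filenames)
    # train ['cn', 'en'] []
    # train/cn [] ['a.bin']
    # train/en [] ['b.bin']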
@@ -400,7 +402,7 @@ def get_packed_dataset_without_short_length(
                     len(catch_ml_keys) < 2
                 ), f"The file name `{fp}` matched the following resample keys:{catch_ml_keys}"
 
-                ds_type_id = get_dataset_type_id(path=fp)
+                ds_type_id = get_dataset_type_id(DATASET_TYPE_IDS_MAP, path=fp)
                 ds = JsonlDataset(fp, ds_type_id, min_length=min_length_num)
 
                 if hasattr(ds, "old_length"):
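get_dataset_type_id resolves a file's type id by matching each map key as a path component with re.search(rf"/[z_]*{key}/", path): the id follows from the directory the file sits in, and exactly one key must match. A standalone illustration of the matching rule; the map and the paths here are assumptions:

    import re

    def match_type_id(dataset_type_ids_map, path):
        # same rule as get_dataset_type_id: the key must appear as a path
        # component, optionally prefixed by "z" or "_" characters
        hits = [idx for key, idx in dataset_type_ids_map.items()
                if re.search(rf"/[z_]*{key}/", path)]
        assert len(hits) == 1, f"{path} matched {hits}"
        return hits[0]

    ids_map = {"cn": 0, "code": 1, "en": 2}  # as derived from cn/, code/, en/ subfolders
    print(match_type_id(ids_map, "/data/train/en/part-0.jsonl"))      # 2
    print(match_type_id(ids_map, "/data/train/z_code/part-1.jsonl"))  # 1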
internlm/data/utils.py

@@ -1,21 +1,26 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
+import os
+import re
+
 import torch
 
 from internlm.core.context import global_context as gpc
 
-DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1, "code": 2}
 
+def get_dataset_type_ids_map(path):
+    dirlist = list(os.listdir(path))
+    dirlist.sort()
+    return {key: idx for idx, key in enumerate(dirlist)}
+
 
-def get_dataset_type_id(path):
-    import re
+def get_dataset_type_id(dataset_type_ids_map, path):
+    match_idxes = []
 
-    match_idxes = []
-    for key, idx in DATASET_TYPE_IDS_MAP.items():
+    for key, idx in dataset_type_ids_map.items():
         if re.search(rf"/[z_]*{key}/", path):
             match_idxes.append(idx)
-    assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {DATASET_TYPE_IDS_MAP}"
+    assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {dataset_type_ids_map}"
     return match_idxes[0]
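Note the behavioral shift in the new get_dataset_type_ids_map: type ids are no longer the fixed {"en": 0, "cn": 1, "code": 2} but follow the alphabetical order of the training folder's top-level entries. A runnable sketch of the difference, with assumed subfolder names:

    import os
    import tempfile

    def get_dataset_type_ids_map(path):
        # ids are assigned in sorted() order of the folder's entries
        dirlist = list(os.listdir(path))
        dirlist.sort()
        return {key: idx for idx, key in enumerate(dirlist)}

    with tempfile.TemporaryDirectory() as root:
        for name in ("en", "cn", "code"):  # assumed dataset-type subfolders
            os.mkdir(os.path.join(root, name))
        print(get_dataset_type_ids_map(root))
    # {'cn': 0, 'code': 1, 'en': 2} -- sorted order, not the old hard-coded one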
internlm/train/training_internlm.py

@@ -33,7 +33,7 @@ from internlm.data.packed_dataset import (
     PackedDatasetWithoutCuSeqlen,
     get_packed_dataset_without_short_length,
 )
-from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
+from internlm.data.utils import get_dataset_type_ids_map, unpack_data
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
     FeedForward,
@@ -133,7 +133,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
 
     # wrap the model
     grp = gpc.get_group(ParallelMode.ZERO1)
-    model = FSDP(
+    model = FSDP(  # pylint: disable=unexpected-keyword-arg
         module=model,
         process_group=grp,
         sharding_strategy=ShardingStrategy.FULL_SHARD,
@@ -219,11 +219,11 @@ def get_train_data_loader(
 
     # Get the dataset types
     dataset_types = None
-    dataset_types = list(DATASET_TYPE_IDS_MAP.keys())
     data_cfg = gpc.config.data
 
     # Get the sample weight dictionary
     train_folder = data_cfg.train_folder
+    dataset_types = list(get_dataset_type_ids_map(train_folder).keys())
 
     if not train_folder:
         train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
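With this last hunk, the data loader derives dataset_types from the configured training folder instead of the removed module-level map. A hedged end-to-end sketch (assumes the internlm package is installed and an assumed folder layout):

    import os
    import tempfile

    from internlm.data.utils import get_dataset_type_ids_map  # helper added by this commit

    with tempfile.TemporaryDirectory() as train_folder:
        for name in ("cn", "code", "en"):
            os.mkdir(os.path.join(train_folder, name))
        # what get_train_data_loader now computes for dataset_types
        dataset_types = list(get_dataset_type_ids_map(train_folder).keys())
        print(dataset_types)  # ['cn', 'code', 'en']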