mirror of https://github.com/InternLM/InternLM
update get_train_data_loader
parent 2b984ffa58
commit f656ff08a6
@@ -4,7 +4,7 @@
 import functools
 import time
 from functools import partial
-from typing import Callable, Iterable, Union
+from typing import Callable, Iterable, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -201,9 +201,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
 
 
 @llm_timeout(func_name="get_train_data_loader")
-def get_train_data_loader(
-    num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
-):
+def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[Callable] = None):
     """
     Generate and return the training data loader.
 
@@ -218,14 +216,14 @@ def get_train_data_loader(
     """
 
     # Get the dataset types
-    dataset_types = None
     data_cfg = gpc.config.data
-
-    # Get the sample weight dictionary
     train_folder = data_cfg.train_folder
     dataset_types = list(get_dataset_type_ids_map(train_folder).keys())
 
-    if not train_folder:
+    if dataset_generate_func is not None:
+        train_ds, train_sampler, train_collate_fn = dataset_generate_func()
+    else:
+        if train_folder is None:
             dataset_types = ["en", "cn", "code"]
             train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
             if data_cfg.pack_sample_into_one:
@@ -236,9 +234,6 @@ def get_train_data_loader(
                 train_ds = PackedDataset(
                     train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
                 )
-    else:
-        if dataset_generate_func is not None:
-            train_ds = dataset_generate_func()
         else:
             train_ds = get_packed_dataset_without_short_length(
                 folder=data_cfg.train_folder,
@@ -249,11 +244,6 @@ def get_train_data_loader(
                 min_length_dict=data_cfg.get("min_length_dict", {}),
                 pack_into_one_sample=data_cfg.pack_sample_into_one,
             )
-
-    if dataset_generate_func is None or not train_folder:
-        # partition already completed
-        assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen, ConcatDataset))
-        # Create the training dataset sampler
         train_sampler = StaticBatchSampler(
             train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
             batch_size=data_cfg.micro_num,
@@ -264,8 +254,6 @@ def get_train_data_loader(
             data_rank=gpc.get_local_rank(ParallelMode.DATA),
             data_world_size=gpc.get_world_size(ParallelMode.DATA),
         )
-
-    if dataset_generate_func is None or not train_folder:
         train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
 
     # Create the training data loader
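With this change, the sampler and collate function are no longer separate parameters of get_train_data_loader: when dataset_generate_func is supplied, it is expected to return the dataset, the sampler, and the collate function together (the new code unpacks train_ds, train_sampler, train_collate_fn from its return value); otherwise everything is still built from gpc.config.data. Below is a minimal sketch of that contract using plain PyTorch placeholders; the dataset, sampler, and collate function are illustrative only, and the import path of get_train_data_loader is an assumption, since it is not shown in this diff.

import torch
from torch.utils.data import Dataset, SequentialSampler


class ToyPackedDataset(Dataset):
    # Placeholder standing in for an InternLM packed dataset.
    def __init__(self, num_samples: int = 16, seq_len: int = 8):
        self.data = torch.randint(0, 100, (num_samples, seq_len))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def toy_collate_fn(batch):
    # Placeholder collate: stack samples into a single tensor batch.
    return torch.stack(batch)


def my_dataset_generate_func():
    # Must return exactly the three objects the new code unpacks:
    #   train_ds, train_sampler, train_collate_fn = dataset_generate_func()
    train_ds = ToyPackedDataset()
    train_sampler = SequentialSampler(train_ds)
    return train_ds, train_sampler, toy_collate_fn


# Assumed usage (import path not shown in this commit):
# from internlm.train import get_train_data_loader
# get_train_data_loader(num_worker=4, dataset_generate_func=my_dataset_generate_func)

Dropping the train_sampler and train_collate_fn parameters keeps the signature small and makes the generate function the single source of truth for custom data pipelines.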