mirror of https://github.com/InternLM/InternLM
feat(train): update get_train_data_loader to make logic clearer (#498)
* update get_train_data_loader * update get_train_data_loader, del old doc --------- Co-authored-by: YWMditto <862779238@qq.com>pull/484/head^2
parent
2b984ffa58
commit
be5b9ea2fa
|
@ -4,7 +4,7 @@
|
|||
import functools
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Callable, Iterable, Union
|
||||
from typing import Callable, Iterable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -201,44 +201,37 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
|
|||
|
||||
|
||||
@llm_timeout(func_name="get_train_data_loader")
|
||||
def get_train_data_loader(
|
||||
num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
|
||||
):
|
||||
def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[Callable] = None):
|
||||
"""
|
||||
Generate and return the training data loader.
|
||||
|
||||
Args:
|
||||
num_worker (:class:`int`): number of subprocesses used for dataloader.
|
||||
dataset_generate_func (:class:`Callable`, optional): generate function for dataset.
|
||||
train_sampler (:class:`torch.utils.data.sampler`, optional): dataset sampler for training dataloader.
|
||||
train_collate_fn (:class:`Callable`, optional): collate function for training dataloader.
|
||||
|
||||
Returns:
|
||||
A tuple of (train_dl, dataset_types).
|
||||
"""
|
||||
|
||||
# Get the dataset types
|
||||
dataset_types = None
|
||||
data_cfg = gpc.config.data
|
||||
|
||||
# Get the sample weight dictionary
|
||||
train_folder = data_cfg.train_folder
|
||||
dataset_types = list(get_dataset_type_ids_map(train_folder).keys())
|
||||
|
||||
if not train_folder:
|
||||
dataset_types = ["en", "cn", "code"]
|
||||
train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
|
||||
if data_cfg.pack_sample_into_one:
|
||||
train_ds = PackedDatasetWithoutCuSeqlen(
|
||||
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
|
||||
)
|
||||
else:
|
||||
train_ds = PackedDataset(
|
||||
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
|
||||
)
|
||||
if dataset_generate_func is not None:
|
||||
train_ds, train_sampler, train_collate_fn = dataset_generate_func()
|
||||
else:
|
||||
if dataset_generate_func is not None:
|
||||
train_ds = dataset_generate_func()
|
||||
if train_folder is None:
|
||||
dataset_types = ["en", "cn", "code"]
|
||||
train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
|
||||
if data_cfg.pack_sample_into_one:
|
||||
train_ds = PackedDatasetWithoutCuSeqlen(
|
||||
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
|
||||
)
|
||||
else:
|
||||
train_ds = PackedDataset(
|
||||
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
|
||||
)
|
||||
else:
|
||||
train_ds = get_packed_dataset_without_short_length(
|
||||
folder=data_cfg.train_folder,
|
||||
|
@ -249,11 +242,6 @@ def get_train_data_loader(
|
|||
min_length_dict=data_cfg.get("min_length_dict", {}),
|
||||
pack_into_one_sample=data_cfg.pack_sample_into_one,
|
||||
)
|
||||
|
||||
if dataset_generate_func is None or not train_folder:
|
||||
# partition already completed
|
||||
assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen, ConcatDataset))
|
||||
# Create the training dataset sampler
|
||||
train_sampler = StaticBatchSampler(
|
||||
train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
|
||||
batch_size=data_cfg.micro_num,
|
||||
|
@ -264,8 +252,6 @@ def get_train_data_loader(
|
|||
data_rank=gpc.get_local_rank(ParallelMode.DATA),
|
||||
data_world_size=gpc.get_world_size(ParallelMode.DATA),
|
||||
)
|
||||
|
||||
if dataset_generate_func is None or not train_folder:
|
||||
train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
|
||||
|
||||
# Create the training data loader
|
||||
|
|
Loading…
Reference in New Issue