feat(train): update get_train_data_loader to make logic clearer (#498)

* update get_train_data_loader

* update get_train_data_loader, del old doc

---------

Co-authored-by: YWMditto <862779238@qq.com>
YWMditto 2023-11-14 17:05:15 +08:00 committed by GitHub
parent 2b984ffa58
commit be5b9ea2fa
1 changed file with 15 additions and 29 deletions
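In short, get_train_data_loader no longer takes train_sampler or train_collate_fn: both are now built inside the function, and a caller-supplied dataset_generate_func is expected to return only the dataset. A sketch of a call site after this change (illustrative, not part of the commit; it assumes internlm has already been initialized so that gpc.config.data is populated, and the import path is taken from the repo's training entry point and may differ):

# Illustrative call site after this change (a sketch, not from the diff below).
from internlm.train import get_train_data_loader  # import path assumed

# No train_sampler / train_collate_fn arguments anymore; both are built internally
# from gpc.config.data. The function returns (train_dl, dataset_types) per its docstring.
train_dl, dataset_types = get_train_data_loader(num_worker=4)

for batch in train_dl:
    ...  # one training step per packed batch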


@@ -4,7 +4,7 @@
 import functools
 import time
 from functools import partial
-from typing import Callable, Iterable, Union
+from typing import Callable, Iterable, Optional, Union
 import torch
 import torch.distributed as dist
@@ -201,44 +201,37 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
 @llm_timeout(func_name="get_train_data_loader")
-def get_train_data_loader(
-    num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
-):
+def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[Callable] = None):
     """
     Generate and return the training data loader.
     Args:
         num_worker (:class:`int`): number of subprocesses used for dataloader.
         dataset_generate_func (:class:`Callable`, optional): generate function for dataset.
-        train_sampler (:class:`torch.utils.data.sampler`, optional): dataset sampler for training dataloader.
-        train_collate_fn (:class:`Callable`, optional): collate function for training dataloader.
     Returns:
         A tuple of (train_dl, dataset_types).
     """
     # Get the dataset types
-    dataset_types = None
     data_cfg = gpc.config.data
     # Get the sample weight dictionary
     train_folder = data_cfg.train_folder
+    dataset_types = list(get_dataset_type_ids_map(train_folder).keys())
-    if not train_folder:
-        dataset_types = ["en", "cn", "code"]
-        train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
-        if data_cfg.pack_sample_into_one:
-            train_ds = PackedDatasetWithoutCuSeqlen(
-                train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
-            )
-        else:
-            train_ds = PackedDataset(
-                train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
-            )
-    if dataset_generate_func is not None:
-        train_ds, train_sampler, train_collate_fn = dataset_generate_func()
-    else:
+    if dataset_generate_func is not None:
+        train_ds = dataset_generate_func()
+    elif train_folder is None:
+        dataset_types = ["en", "cn", "code"]
+        train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
+        if data_cfg.pack_sample_into_one:
+            train_ds = PackedDatasetWithoutCuSeqlen(
+                train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
+            )
+        else:
+            train_ds = PackedDataset(
+                train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
+            )
+    else:
         train_ds = get_packed_dataset_without_short_length(
             folder=data_cfg.train_folder,
@@ -249,11 +242,6 @@ def get_train_data_loader(
             min_length_dict=data_cfg.get("min_length_dict", {}),
             pack_into_one_sample=data_cfg.pack_sample_into_one,
         )
-    if dataset_generate_func is None or not train_folder:
-        # partition already completed
-        assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen, ConcatDataset))
     # Create the training dataset sampler
     train_sampler = StaticBatchSampler(
         train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
         batch_size=data_cfg.micro_num,
@@ -264,8 +252,6 @@ def get_train_data_loader(
         data_rank=gpc.get_local_rank(ParallelMode.DATA),
         data_world_size=gpc.get_world_size(ParallelMode.DATA),
     )
-    if dataset_generate_func is None or not train_folder:
-        train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
+    train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
     # Create the training data loader
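
The hunk is cut off here, before the dataloader itself is constructed. As a rough sketch of the tail of the function (assuming the usual torch DataLoader wiring; the exact arguments are not shown in this excerpt), the pieces built above combine roughly like this:

from torch.utils.data import DataLoader

# Rough sketch of the (not shown) tail of get_train_data_loader: the packed dataset,
# the StaticBatchSampler and the packed_collate_fn partial feed a standard DataLoader.
# Argument names and values here are assumptions, not copied from the commit.
train_dl = DataLoader(
    dataset=train_ds,
    batch_sampler=train_sampler,   # StaticBatchSampler built above
    num_workers=num_worker,
    pin_memory=True,
    collate_fn=train_collate_fn,   # partial(packed_collate_fn, ...) built above
)
# The function then returns (train_dl, dataset_types), per the docstring above.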