feat(train): update get_train_data_loader to make logic clearer (#498)

* update get_train_data_loader

* update get_train_data_loader, delete outdated docstring entries

---------

Co-authored-by: YWMditto <862779238@qq.com>
YWMditto 2023-11-14 17:05:15 +08:00 committed by GitHub
parent 2b984ffa58
commit be5b9ea2fa
1 changed file with 15 additions and 29 deletions


@@ -4,7 +4,7 @@
 import functools
 import time
 from functools import partial
-from typing import Callable, Iterable, Union
+from typing import Callable, Iterable, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -201,44 +201,37 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
 @llm_timeout(func_name="get_train_data_loader")
-def get_train_data_loader(
-    num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
-):
+def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[Callable] = None):
     """
     Generate and return the training data loader.
 
     Args:
         num_worker (:class:`int`): number of subprocesses used for dataloader.
         dataset_generate_func (:class:`Callable`, optional): generate function for dataset.
-        train_sampler (:class:`torch.utils.data.sampler`, optional): dataset sampler for training dataloader.
-        train_collate_fn (:class:`Callable`, optional): collate function for training dataloader.
 
     Returns:
         A tuple of (train_dl, dataset_types).
     """
 
     # Get the dataset types
-    dataset_types = None
     data_cfg = gpc.config.data
-
-    # Get the sample weight dictionary
     train_folder = data_cfg.train_folder
     dataset_types = list(get_dataset_type_ids_map(train_folder).keys())
 
-    if not train_folder:
-        dataset_types = ["en", "cn", "code"]
-        train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
-        if data_cfg.pack_sample_into_one:
-            train_ds = PackedDatasetWithoutCuSeqlen(
-                train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
-            )
-        else:
-            train_ds = PackedDataset(
-                train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
-            )
+    if dataset_generate_func is not None:
+        train_ds, train_sampler, train_collate_fn = dataset_generate_func()
     else:
-        if dataset_generate_func is not None:
-            train_ds = dataset_generate_func()
+        if train_folder is None:
+            dataset_types = ["en", "cn", "code"]
+            train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
+            if data_cfg.pack_sample_into_one:
+                train_ds = PackedDatasetWithoutCuSeqlen(
+                    train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
+                )
+            else:
+                train_ds = PackedDataset(
+                    train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
+                )
         else:
             train_ds = get_packed_dataset_without_short_length(
                 folder=data_cfg.train_folder,
@@ -249,11 +242,6 @@ def get_train_data_loader(
                 min_length_dict=data_cfg.get("min_length_dict", {}),
                 pack_into_one_sample=data_cfg.pack_sample_into_one,
             )
-
-    if dataset_generate_func is None or not train_folder:
-        # partition already completed
-        assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen, ConcatDataset))
-        # Create the training dataset sampler
         train_sampler = StaticBatchSampler(
             train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
             batch_size=data_cfg.micro_num,
@@ -264,8 +252,6 @@ def get_train_data_loader(
             data_rank=gpc.get_local_rank(ParallelMode.DATA),
             data_world_size=gpc.get_world_size(ParallelMode.DATA),
         )
-
-    if dataset_generate_func is None or not train_folder:
         train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
 
     # Create the training data loader
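
With this change, a caller that passes dataset_generate_func is expected to return the dataset, batch sampler, and collate function together, since the train_sampler and train_collate_fn parameters are removed and the built-in StaticBatchSampler / packed_collate_fn path now only runs in the else branch. Below is a minimal caller-side sketch under that assumption; ToyDataset, toy_collate_fn, and my_dataset_generate_func are hypothetical placeholders, not part of this patch.

from functools import partial

import torch
from torch.utils.data import BatchSampler, Dataset, RandomSampler


class ToyDataset(Dataset):
    # Placeholder dataset standing in for a real pretraining dataset.
    def __init__(self, num_samples: int = 1024, seq_len: int = 2048):
        self.num_samples = num_samples
        self.seq_len = seq_len

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {"input_ids": torch.randint(0, 32000, (self.seq_len,))}


def toy_collate_fn(batch, packed_length: int = 2048):
    # Placeholder collate function; a real one would pack samples up to packed_length.
    input_ids = torch.stack([sample["input_ids"][:packed_length] for sample in batch])
    return {"input_ids": input_ids}


def my_dataset_generate_func():
    # Hypothetical generator returning the (dataset, sampler, collate_fn) triple
    # that the refactored get_train_data_loader unpacks.
    train_ds = ToyDataset()
    train_sampler = BatchSampler(RandomSampler(train_ds), batch_size=4, drop_last=True)
    train_collate_fn = partial(toy_collate_fn, packed_length=2048)
    return train_ds, train_sampler, train_collate_fn


# The loader is then requested without passing a sampler or collate_fn explicitly:
# train_dl, dataset_types = get_train_data_loader(
#     num_worker=4, dataset_generate_func=my_dataset_generate_func
# )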