update get_train_data_loader

Branch: pull/498/head
Author: YWMditto, 2023-11-14 15:39:43 +08:00
parent 2b984ffa58
commit f656ff08a6
1 changed file with 15 additions and 27 deletions

@@ -4,7 +4,7 @@
 import functools
 import time
 from functools import partial
-from typing import Callable, Iterable, Union
+from typing import Callable, Iterable, Optional, Union
 import torch
 import torch.distributed as dist
@@ -201,9 +201,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
 @llm_timeout(func_name="get_train_data_loader")
-def get_train_data_loader(
-    num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
-):
+def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[Callable] = None):
     """
     Generate and return the training data loader.
@@ -218,27 +216,24 @@ def get_train_data_loader(
     """
     # Get the dataset types
-    dataset_types = None
     data_cfg = gpc.config.data
-    # Get the sample weight dictionary
     train_folder = data_cfg.train_folder
     dataset_types = list(get_dataset_type_ids_map(train_folder).keys())
-    if not train_folder:
-        dataset_types = ["en", "cn", "code"]
-        train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
-        if data_cfg.pack_sample_into_one:
-            train_ds = PackedDatasetWithoutCuSeqlen(
-                train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
-            )
-        else:
-            train_ds = PackedDataset(
-                train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
-            )
+    if dataset_generate_func is not None:
+        train_ds, train_sampler, train_collate_fn = dataset_generate_func()
     else:
-        if dataset_generate_func is not None:
-            train_ds = dataset_generate_func()
+        if train_folder is None:
+            dataset_types = ["en", "cn", "code"]
+            train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
+            if data_cfg.pack_sample_into_one:
+                train_ds = PackedDatasetWithoutCuSeqlen(
+                    train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
+                )
+            else:
+                train_ds = PackedDataset(
+                    train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
+                )
         else:
             train_ds = get_packed_dataset_without_short_length(
                 folder=data_cfg.train_folder,
@@ -249,11 +244,6 @@
                 min_length_dict=data_cfg.get("min_length_dict", {}),
                 pack_into_one_sample=data_cfg.pack_sample_into_one,
            )
-    if dataset_generate_func is None or not train_folder:
-        # partition already completed
-        assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen, ConcatDataset))
-    # Create the training dataset sampler
     train_sampler = StaticBatchSampler(
         train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
         batch_size=data_cfg.micro_num,
@@ -264,8 +254,6 @@
         data_rank=gpc.get_local_rank(ParallelMode.DATA),
         data_world_size=gpc.get_world_size(ParallelMode.DATA),
     )
-    if dataset_generate_func is None or not train_folder:
-        train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
+    train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
     # Create the training data loader
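
With this change, dataset_generate_func is expected to build everything itself and return the dataset, batch sampler, and collate function as a triple, matching the unpacking `train_ds, train_sampler, train_collate_fn = dataset_generate_func()` above; the default path still constructs StaticBatchSampler and packed_collate_fn. Below is a minimal sketch of such a callable, assuming a plain map-style torch dataset (ToyDataset, pad_collate, and make_train_objects are illustrative names, not part of the repository):

# Illustrative sketch only: shows the (dataset, sampler, collate_fn) triple that the
# updated get_train_data_loader expects from dataset_generate_func.
# ToyDataset, pad_collate, and make_train_objects are hypothetical names.
from functools import partial

import torch
from torch.utils.data import Dataset, SequentialSampler


class ToyDataset(Dataset):
    """A tiny map-style dataset of fixed-length token sequences."""

    def __init__(self, num_samples: int = 8, seq_len: int = 16):
        self.data = torch.randint(0, 100, (num_samples, seq_len))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def pad_collate(batch, packed_length: int):
    # Stack samples and truncate to packed_length, mirroring a packed-style collate.
    return torch.stack(batch)[:, :packed_length]


def make_train_objects():
    # Returns the triple unpacked by the new branch:
    # train_ds, train_sampler, train_collate_fn = dataset_generate_func()
    ds = ToyDataset()
    sampler = SequentialSampler(ds)
    collate_fn = partial(pad_collate, packed_length=16)
    return ds, sampler, collate_fn


# Hypothetical call site:
# get_train_data_loader(num_worker=4, dataset_generate_func=make_train_objects)

Returning the sampler and collate function from the callable, rather than passing them in as the removed train_sampler and train_collate_fn parameters did, keeps all dataset-specific wiring in one place.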