mirror of https://github.com/InternLM/InternLM
feat(train): update get_train_data_loader to make logic clearer (#498)
* update get_train_data_loader * update get_train_data_loader, del old doc --------- Co-authored-by: YWMditto <862779238@qq.com>pull/484/head^2
parent
2b984ffa58
commit
be5b9ea2fa
|
@ -4,7 +4,7 @@
|
||||||
import functools
|
import functools
|
||||||
import time
|
import time
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Callable, Iterable, Union
|
from typing import Callable, Iterable, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
@ -201,31 +201,27 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
|
||||||
|
|
||||||
|
|
||||||
@llm_timeout(func_name="get_train_data_loader")
|
@llm_timeout(func_name="get_train_data_loader")
|
||||||
def get_train_data_loader(
|
def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[Callable] = None):
|
||||||
num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Generate and return the training data loader.
|
Generate and return the training data loader.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_worker (:class:`int`): number of subprocesses used for dataloader.
|
num_worker (:class:`int`): number of subprocesses used for dataloader.
|
||||||
dataset_generate_func (:class:`Callable`, optional): generate function for dataset.
|
dataset_generate_func (:class:`Callable`, optional): generate function for dataset.
|
||||||
train_sampler (:class:`torch.utils.data.sampler`, optional): dataset sampler for training dataloader.
|
|
||||||
train_collate_fn (:class:`Callable`, optional): collate function for training dataloader.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (train_dl, dataset_types).
|
A tuple of (train_dl, dataset_types).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Get the dataset types
|
# Get the dataset types
|
||||||
dataset_types = None
|
|
||||||
data_cfg = gpc.config.data
|
data_cfg = gpc.config.data
|
||||||
|
|
||||||
# Get the sample weight dictionary
|
|
||||||
train_folder = data_cfg.train_folder
|
train_folder = data_cfg.train_folder
|
||||||
dataset_types = list(get_dataset_type_ids_map(train_folder).keys())
|
dataset_types = list(get_dataset_type_ids_map(train_folder).keys())
|
||||||
|
|
||||||
if not train_folder:
|
if dataset_generate_func is not None:
|
||||||
|
train_ds, train_sampler, train_collate_fn = dataset_generate_func()
|
||||||
|
else:
|
||||||
|
if train_folder is None:
|
||||||
dataset_types = ["en", "cn", "code"]
|
dataset_types = ["en", "cn", "code"]
|
||||||
train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
|
train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
|
||||||
if data_cfg.pack_sample_into_one:
|
if data_cfg.pack_sample_into_one:
|
||||||
|
@ -236,9 +232,6 @@ def get_train_data_loader(
|
||||||
train_ds = PackedDataset(
|
train_ds = PackedDataset(
|
||||||
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
|
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
if dataset_generate_func is not None:
|
|
||||||
train_ds = dataset_generate_func()
|
|
||||||
else:
|
else:
|
||||||
train_ds = get_packed_dataset_without_short_length(
|
train_ds = get_packed_dataset_without_short_length(
|
||||||
folder=data_cfg.train_folder,
|
folder=data_cfg.train_folder,
|
||||||
|
@ -249,11 +242,6 @@ def get_train_data_loader(
|
||||||
min_length_dict=data_cfg.get("min_length_dict", {}),
|
min_length_dict=data_cfg.get("min_length_dict", {}),
|
||||||
pack_into_one_sample=data_cfg.pack_sample_into_one,
|
pack_into_one_sample=data_cfg.pack_sample_into_one,
|
||||||
)
|
)
|
||||||
|
|
||||||
if dataset_generate_func is None or not train_folder:
|
|
||||||
# partition already completed
|
|
||||||
assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen, ConcatDataset))
|
|
||||||
# Create the training dataset sampler
|
|
||||||
train_sampler = StaticBatchSampler(
|
train_sampler = StaticBatchSampler(
|
||||||
train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
|
train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
|
||||||
batch_size=data_cfg.micro_num,
|
batch_size=data_cfg.micro_num,
|
||||||
|
@ -264,8 +252,6 @@ def get_train_data_loader(
|
||||||
data_rank=gpc.get_local_rank(ParallelMode.DATA),
|
data_rank=gpc.get_local_rank(ParallelMode.DATA),
|
||||||
data_world_size=gpc.get_world_size(ParallelMode.DATA),
|
data_world_size=gpc.get_world_size(ParallelMode.DATA),
|
||||||
)
|
)
|
||||||
|
|
||||||
if dataset_generate_func is None or not train_folder:
|
|
||||||
train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
|
train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
|
||||||
|
|
||||||
# Create the training data loader
|
# Create the training data loader
|
||||||
|
|
Loading…
Reference in New Issue