diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
index a2cfb2ef6..5115e4563 100644
--- a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
@@ -1,20 +1,16 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-

-import numpy as np
 import os
-import random
 from dataclasses import dataclass
-from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
+from typing import Dict, Iterator, List, Optional, Sequence, Union

 import torch
-from datasets import dataset_dict, load_from_disk
-from datasets import Dataset as HFDataset
-from torch.distributed import ProcessGroup
-from torch.distributed.distributed_c10d import _get_default_group
-from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
-from transformers.tokenization_utils import PreTrainedTokenizer
 import torch.nn.functional as F
+from datasets import Dataset as HFDataset
+from datasets import dataset_dict, load_from_disk
+from torch.utils.data import ConcatDataset, Dataset, DistributedSampler
+from transformers.tokenization_utils import PreTrainedTokenizer

 DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
 PathType = Union[str, os.PathLike]
@@ -171,49 +167,3 @@ class StatefulDistributedSampler(DistributedSampler):

     def set_start_index(self, start_index: int) -> None:
         self.start_index = start_index
-
-
-def setup_distributed_dataloader(
-    dataset: DatasetType,
-    batch_size: int = 1,
-    shuffle: bool = False,
-    seed: int = 1024,
-    drop_last: bool = False,
-    pin_memory: bool = False,
-    num_workers: int = 0,
-    collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
-    process_group: Optional[ProcessGroup] = None,
-    **kwargs,
-) -> DataLoader:
-    """
-    Setup dataloader for distributed training.
-    """
-    _kwargs = kwargs.copy()
-    process_group = process_group or _get_default_group()
-    sampler = StatefulDistributedSampler(
-        dataset=dataset,
-        num_replicas=process_group.size(),
-        rank=process_group.rank(),
-        shuffle=shuffle,
-        seed=seed,
-        drop_last=drop_last,
-    )
-
-    # Deterministic dataloader
-    def seed_worker(worker_id: int) -> None:
-        worker_seed = seed
-        np.random.seed(worker_seed)
-        torch.manual_seed(worker_seed)
-        random.seed(worker_seed)
-
-    return DataLoader(
-        dataset=dataset,
-        batch_size=batch_size,
-        sampler=sampler,
-        num_workers=num_workers,
-        collate_fn=collate_fn,
-        pin_memory=pin_memory,
-        drop_last=drop_last,
-        worker_init_fn=seed_worker,
-        **_kwargs,
-    )
diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py
index 92863e8e4..4aecd46d7 100644
--- a/applications/Colossal-LLaMA-2/train.py
+++ b/applications/Colossal-LLaMA-2/train.py
@@ -16,7 +16,6 @@ from colossal_llama2.dataset.loader import (
     DataCollatorForSupervisedDataset,
     StatefulDistributedSampler,
     load_tokenized_dataset,
-    setup_distributed_dataloader,
 )
 from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
 from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
@@ -194,12 +193,13 @@ def main() -> None:

     dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
-    dataloader = setup_distributed_dataloader(
+    dataloader = plugin.prepare_dataloader(
         dataset=dataset,
         batch_size=args.micro_batch_size,
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
+        distributed_sampler_cls=StatefulDistributedSampler,
     )
     coordinator.print_on_master(
         f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
diff --git a/applications/Colossal-LLaMA-2/train_sft.py b/applications/Colossal-LLaMA-2/train_sft.py
index fd9e1cd3e..27a7ce096 100644
--- a/applications/Colossal-LLaMA-2/train_sft.py
+++ b/applications/Colossal-LLaMA-2/train_sft.py
@@ -16,7 +16,6 @@ from colossal_llama2.dataset.loader import (
     DataCollatorForSupervisedDataset,
     StatefulDistributedSampler,
     load_tokenized_dataset,
-    setup_distributed_dataloader,
 )
 from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
 from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
@@ -203,12 +202,13 @@ def main() -> None:

     dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
-    dataloader = setup_distributed_dataloader(
+    dataloader = plugin.prepare_dataloader(
         dataset=dataset,
         batch_size=args.micro_batch_size,
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
+        distributed_sampler_cls=StatefulDistributedSampler,
     )
     coordinator.print_on_master(
         f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
diff --git a/colossalai/booster/plugin/dp_plugin_base.py b/colossalai/booster/plugin/dp_plugin_base.py
index d2dd00453..27285f95c 100644
--- a/colossalai/booster/plugin/dp_plugin_base.py
+++ b/colossalai/booster/plugin/dp_plugin_base.py
@@ -21,7 +21,16 @@ class DPPluginBase(Plugin):
         self.world_size = dist.get_world_size()

     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -45,7 +54,8 @@ class DPPluginBase(Plugin):
             :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
         """
         _kwargs = kwargs.copy()
-        sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)

         # Deterministic dataloader
         def seed_worker(worker_id):
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index d14109dd4..95b96bbfd 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -456,7 +456,16 @@ class GeminiPlugin(DPPluginBase):
         return ["cuda", "npu"]

     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -484,7 +493,8 @@ class GeminiPlugin(DPPluginBase):
         extra_dp_world_size = self.pg_mesh.size(DP_AXIS)
         zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
         extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
-        sampler = DistributedSampler(
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(
             dataset,
             num_replicas=zero_world_size * extra_dp_world_size,
             rank=zero_rank * extra_dp_world_size + extra_dp_rank,
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 943e137e6..da67e6b41 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1205,7 +1205,16 @@ class HybridParallelPlugin(PipelinePluginBase):
         return outputs

     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -1229,7 +1238,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
         """
         _kwargs = kwargs.copy()
-        sampler = DistributedSampler(
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(
             dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
         )
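
Usage note (not part of the patch): with setup_distributed_dataloader removed, training scripts obtain their dataloader from the booster plugin and opt into resumable sampling by passing StatefulDistributedSampler through the new distributed_sampler_cls argument; because the argument defaults to None, existing callers fall back to torch.utils.data.DistributedSampler unchanged. The sketch below illustrates the call pattern under the assumption of a Gemini setup similar to the Colossal-LLaMA-2 scripts; the tokenizer path, dataset path, batch size, and resume step are placeholders.

# Minimal sketch (illustrative, not taken from the patch). Assumes colossalai.launch_from_torch
# has already initialized the distributed environment, as the training scripts do.
from colossal_llama2.dataset.loader import (
    DataCollatorForSupervisedDataset,
    StatefulDistributedSampler,
    load_tokenized_dataset,
)
from colossalai.booster.plugin import GeminiPlugin
from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("/path/to/pretrained")  # placeholder path
plugin = GeminiPlugin()  # any plugin exposing prepare_dataloader works the same way

dataset = load_tokenized_dataset(dataset_paths=["/path/to/tokenized_dataset"], mode="train")
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=4096)

# The plugin now builds the sampler itself; passing StatefulDistributedSampler keeps
# the mid-epoch resume capability that setup_distributed_dataloader used to provide.
micro_batch_size = 4  # placeholder
dataloader = plugin.prepare_dataloader(
    dataset=dataset,
    batch_size=micro_batch_size,
    shuffle=True,
    drop_last=True,
    collate_fn=data_collator,
    distributed_sampler_cls=StatefulDistributedSampler,
)

# When resuming from a checkpoint, the sampler's start index can still be restored:
start_step = 0  # placeholder: step loaded from the checkpoint
dataloader.sampler.set_start_index(start_index=start_step * micro_batch_size)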