[llama] fix dataloader for hybrid parallel (#5358)

* [plugin] refactor prepare dataloader

* [plugin] update train script
Hongxin Liu 2024-02-05 15:14:56 +08:00 committed by GitHub
parent 2dd01e3a14
commit 6c0fa7b9a8
6 changed files with 45 additions and 65 deletions
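In short, the application-level setup_distributed_dataloader helper is removed and the booster plugins' prepare_dataloader gains a distributed_sampler_cls argument, so the train scripts can keep using the resumable StatefulDistributedSampler. A minimal sketch of the call pattern after this change (plugin choice, launch call, and all argument values are illustrative, not taken from the diff):

# Hedged sketch of the new call pattern; assumes a torchrun launch and the
# ColossalAI API as of this commit. Plugin arguments and batch size are illustrative.
import colossalai
import torch
from torch.utils.data import TensorDataset

from colossal_llama2.dataset.loader import StatefulDistributedSampler
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})                  # distributed init under torchrun
plugin = HybridParallelPlugin(tp_size=1, pp_size=1)      # illustrative parallel sizes

dataset = TensorDataset(torch.arange(1024).unsqueeze(1)) # stand-in dataset
dataloader = plugin.prepare_dataloader(
    dataset=dataset,
    batch_size=4,                                        # per-replica micro batch size
    shuffle=True,
    drop_last=True,
    distributed_sampler_cls=StatefulDistributedSampler,  # resumable sampler from loader.py
)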


@@ -1,20 +1,16 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-import numpy as np
 import os
-import random
 from dataclasses import dataclass
-from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
+from typing import Dict, Iterator, List, Optional, Sequence, Union
 import torch
-from datasets import dataset_dict, load_from_disk
-from datasets import Dataset as HFDataset
-from torch.distributed import ProcessGroup
-from torch.distributed.distributed_c10d import _get_default_group
-from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
-from transformers.tokenization_utils import PreTrainedTokenizer
 import torch.nn.functional as F
+from datasets import Dataset as HFDataset
+from datasets import dataset_dict, load_from_disk
+from torch.utils.data import ConcatDataset, Dataset, DistributedSampler
+from transformers.tokenization_utils import PreTrainedTokenizer
 DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
 PathType = Union[str, os.PathLike]
@@ -171,49 +167,3 @@ class StatefulDistributedSampler(DistributedSampler):
     def set_start_index(self, start_index: int) -> None:
         self.start_index = start_index
-def setup_distributed_dataloader(
-    dataset: DatasetType,
-    batch_size: int = 1,
-    shuffle: bool = False,
-    seed: int = 1024,
-    drop_last: bool = False,
-    pin_memory: bool = False,
-    num_workers: int = 0,
-    collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
-    process_group: Optional[ProcessGroup] = None,
-    **kwargs,
-) -> DataLoader:
-    """
-    Setup dataloader for distributed training.
-    """
-    _kwargs = kwargs.copy()
-    process_group = process_group or _get_default_group()
-    sampler = StatefulDistributedSampler(
-        dataset=dataset,
-        num_replicas=process_group.size(),
-        rank=process_group.rank(),
-        shuffle=shuffle,
-        seed=seed,
-        drop_last=drop_last,
-    )
-    # Deterministic dataloader
-    def seed_worker(worker_id: int) -> None:
-        worker_seed = seed
-        np.random.seed(worker_seed)
-        torch.manual_seed(worker_seed)
-        random.seed(worker_seed)
-    return DataLoader(
-        dataset=dataset,
-        batch_size=batch_size,
-        sampler=sampler,
-        num_workers=num_workers,
-        collate_fn=collate_fn,
-        pin_memory=pin_memory,
-        drop_last=drop_last,
-        worker_init_fn=seed_worker,
-        **_kwargs,
-    )
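The set_start_index hook kept above is what the train scripts rely on when resuming inside an epoch; a hedged sketch of that resume path (the epoch, start_step, and batch-size values are illustrative, not from this diff):

# Hedged sketch: skipping already-consumed samples after a restart with
# StatefulDistributedSampler. Variable values below are illustrative; `dataloader`
# is built by plugin.prepare_dataloader(...) as in the train scripts.
epoch, start_step, micro_batch_size = 0, 1000, 4

sampler = dataloader.sampler
assert isinstance(sampler, StatefulDistributedSampler)

sampler.set_epoch(epoch)                                 # usual DistributedSampler epoch seeding
sampler.set_start_index(start_step * micro_batch_size)   # illustrative bookkeeping of consumed samples

for step, batch in enumerate(dataloader, start=start_step):
    ...                                                  # training step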


@@ -16,7 +16,6 @@ from colossal_llama2.dataset.loader import (
     DataCollatorForSupervisedDataset,
     StatefulDistributedSampler,
     load_tokenized_dataset,
-    setup_distributed_dataloader,
 )
 from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
 from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
@@ -194,12 +193,13 @@ def main() -> None:
     dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
-    dataloader = setup_distributed_dataloader(
+    dataloader = plugin.prepare_dataloader(
         dataset=dataset,
         batch_size=args.micro_batch_size,
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
+        distributed_sampler_cls=StatefulDistributedSampler,
     )
     coordinator.print_on_master(
         f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"


@@ -16,7 +16,6 @@ from colossal_llama2.dataset.loader import (
     DataCollatorForSupervisedDataset,
     StatefulDistributedSampler,
     load_tokenized_dataset,
-    setup_distributed_dataloader,
 )
 from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
 from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
@@ -203,12 +202,13 @@ def main() -> None:
     dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
-    dataloader = setup_distributed_dataloader(
+    dataloader = plugin.prepare_dataloader(
         dataset=dataset,
         batch_size=args.micro_batch_size,
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
+        distributed_sampler_cls=StatefulDistributedSampler,
     )
     coordinator.print_on_master(
         f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"


@@ -21,7 +21,16 @@ class DPPluginBase(Plugin):
         self.world_size = dist.get_world_size()
     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -45,7 +54,8 @@ class DPPluginBase(Plugin):
             :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
         """
         _kwargs = kwargs.copy()
-        sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
         # Deterministic dataloader
         def seed_worker(worker_id):

@@ -456,7 +456,16 @@ class GeminiPlugin(DPPluginBase):
         return ["cuda", "npu"]
     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -484,7 +493,8 @@ class GeminiPlugin(DPPluginBase):
         extra_dp_world_size = self.pg_mesh.size(DP_AXIS)
         zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
         extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
-        sampler = DistributedSampler(
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(
             dataset,
             num_replicas=zero_world_size * extra_dp_world_size,
             rank=zero_rank * extra_dp_world_size + extra_dp_rank,
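The ZeRO and extra data-parallel axes are folded into one flat sampler group here; a small worked example of that rank formula (the mesh sizes are illustrative):

# Worked example of the flattened data-parallel rank used by GeminiPlugin's sampler.
# Mesh sizes are illustrative; the formula is the one in the diff above.
zero_world_size, extra_dp_world_size = 4, 2

for zero_rank in range(zero_world_size):
    for extra_dp_rank in range(extra_dp_world_size):
        flat_rank = zero_rank * extra_dp_world_size + extra_dp_rank
        print(f"zero={zero_rank} extra_dp={extra_dp_rank} -> sampler rank {flat_rank}")

# Every (zero_rank, extra_dp_rank) pair maps to a distinct rank in [0, 8), so the
# dataset is sharded across all 4 * 2 = 8 replicas without overlap or gaps.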


@@ -1205,7 +1205,16 @@ class HybridParallelPlugin(PipelinePluginBase):
         return outputs
     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -1229,7 +1238,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
         """
         _kwargs = kwargs.copy()
-        sampler = DistributedSampler(
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(
             dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
         )