mirror of https://github.com/hpcaitech/ColossalAI
[llama] fix dataloader for hybrid parallel (#5358)
* [plugin] refactor prepare dataloader
* [plugin] update train script

pull/5362/head
parent 2dd01e3a14
commit 6c0fa7b9a8
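The change replaces the application-owned setup_distributed_dataloader helper with the booster plugins' prepare_dataloader, which gains a distributed_sampler_cls argument so the LLaMA train scripts can keep their resumable StatefulDistributedSampler. Below is a hedged usage sketch of the new call pattern, not code from this commit: the toy dataset, the parallel sizes, the start index and the launch call are illustrative assumptions, and it assumes a torchrun launch with ColossalAI and the Colossal-LLaMA-2 sources importable.

# Illustrative sketch only; run under torchrun. The launch signature may differ by
# ColossalAI version (a config dict was still expected at the time of this commit).
import colossalai
import torch
from colossal_llama2.dataset.loader import StatefulDistributedSampler  # assumes the app package is on the path
from colossalai.booster.plugin import HybridParallelPlugin
from torch.utils.data import Dataset


class ToyDataset(Dataset):  # stand-in for the tokenized LLaMA dataset
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return {"input_ids": torch.tensor([idx]), "labels": torch.tensor([idx])}


colossalai.launch_from_torch(config={})
plugin = HybridParallelPlugin(tp_size=1, pp_size=1)

# The plugin now owns dataloader construction; the script only injects its sampler class.
# The real scripts also pass collate_fn=DataCollatorForSupervisedDataset(...).
dataloader = plugin.prepare_dataloader(
    ToyDataset(),
    batch_size=4,
    shuffle=True,
    drop_last=True,
    distributed_sampler_cls=StatefulDistributedSampler,
)

# When resuming from a checkpoint, already-consumed samples can be skipped (8 is arbitrary):
dataloader.sampler.set_start_index(start_index=8)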
@@ -1,20 +1,16 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
-import numpy as np
 import os
-import random
 from dataclasses import dataclass
-from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
+from typing import Dict, Iterator, List, Optional, Sequence, Union
 
 import torch
-from datasets import dataset_dict, load_from_disk
-from datasets import Dataset as HFDataset
-from torch.distributed import ProcessGroup
-from torch.distributed.distributed_c10d import _get_default_group
-from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
-from transformers.tokenization_utils import PreTrainedTokenizer
 import torch.nn.functional as F
+from datasets import Dataset as HFDataset
+from datasets import dataset_dict, load_from_disk
+from torch.utils.data import ConcatDataset, Dataset, DistributedSampler
+from transformers.tokenization_utils import PreTrainedTokenizer
 
 DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
 PathType = Union[str, os.PathLike]
@@ -171,49 +167,3 @@ class StatefulDistributedSampler(DistributedSampler):
 
     def set_start_index(self, start_index: int) -> None:
         self.start_index = start_index
-
-
-def setup_distributed_dataloader(
-    dataset: DatasetType,
-    batch_size: int = 1,
-    shuffle: bool = False,
-    seed: int = 1024,
-    drop_last: bool = False,
-    pin_memory: bool = False,
-    num_workers: int = 0,
-    collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
-    process_group: Optional[ProcessGroup] = None,
-    **kwargs,
-) -> DataLoader:
-    """
-    Setup dataloader for distributed training.
-    """
-    _kwargs = kwargs.copy()
-    process_group = process_group or _get_default_group()
-    sampler = StatefulDistributedSampler(
-        dataset=dataset,
-        num_replicas=process_group.size(),
-        rank=process_group.rank(),
-        shuffle=shuffle,
-        seed=seed,
-        drop_last=drop_last,
-    )
-
-    # Deterministic dataloader
-    def seed_worker(worker_id: int) -> None:
-        worker_seed = seed
-        np.random.seed(worker_seed)
-        torch.manual_seed(worker_seed)
-        random.seed(worker_seed)
-
-    return DataLoader(
-        dataset=dataset,
-        batch_size=batch_size,
-        sampler=sampler,
-        num_workers=num_workers,
-        collate_fn=collate_fn,
-        pin_memory=pin_memory,
-        drop_last=drop_last,
-        worker_init_fn=seed_worker,
-        **_kwargs,
-    )
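setup_distributed_dataloader is removed, but StatefulDistributedSampler and the set_start_index hook kept above remain the application's resume mechanism; they are now handed to the plugin through distributed_sampler_cls. The following is a minimal pure-PyTorch sketch of that resume idea, not the ColossalAI class: the toy dataset and the index slicing in __iter__ are assumptions made for the demo.

from torch.utils.data import Dataset, DistributedSampler


class ToyDataset(Dataset):  # hypothetical stand-in dataset
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return idx


class ResumableSampler(DistributedSampler):
    """Skips the first `start_index` samples of the epoch after a restart."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.start_index = 0

    def set_start_index(self, start_index: int) -> None:
        self.start_index = start_index

    def __iter__(self):
        indices = list(super().__iter__())
        return iter(indices[self.start_index:])

    def __len__(self):
        return self.num_samples - self.start_index


# Single-process demo: num_replicas and rank are given explicitly, so no process group is needed.
sampler = ResumableSampler(ToyDataset(), num_replicas=1, rank=0, shuffle=False)
print(list(sampler))        # [0, 1, ..., 9]
sampler.set_start_index(4)  # pretend 4 samples were consumed before the restart
print(list(sampler))        # [4, 5, ..., 9]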
@@ -16,7 +16,6 @@ from colossal_llama2.dataset.loader import (
     DataCollatorForSupervisedDataset,
     StatefulDistributedSampler,
     load_tokenized_dataset,
-    setup_distributed_dataloader,
 )
 from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
 from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
@@ -194,12 +193,13 @@ def main() -> None:
 
     dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
-    dataloader = setup_distributed_dataloader(
+    dataloader = plugin.prepare_dataloader(
         dataset=dataset,
         batch_size=args.micro_batch_size,
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
+        distributed_sampler_cls=StatefulDistributedSampler,
     )
     coordinator.print_on_master(
         f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
@@ -16,7 +16,6 @@ from colossal_llama2.dataset.loader import (
     DataCollatorForSupervisedDataset,
     StatefulDistributedSampler,
     load_tokenized_dataset,
-    setup_distributed_dataloader,
 )
 from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
 from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
@@ -203,12 +202,13 @@ def main() -> None:
 
     dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
-    dataloader = setup_distributed_dataloader(
+    dataloader = plugin.prepare_dataloader(
         dataset=dataset,
         batch_size=args.micro_batch_size,
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
+        distributed_sampler_cls=StatefulDistributedSampler,
     )
     coordinator.print_on_master(
         f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
@@ -21,7 +21,16 @@ class DPPluginBase(Plugin):
         self.world_size = dist.get_world_size()
 
     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -45,7 +54,8 @@ class DPPluginBase(Plugin):
             :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
         """
         _kwargs = kwargs.copy()
-        sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
 
         # Deterministic dataloader
         def seed_worker(worker_id):
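The new distributed_sampler_cls argument is plain dependency injection: a caller may pass any DistributedSampler subclass, and the plugin falls back to torch's stock sampler when none is given. A standalone sketch of the same pattern (make_dataloader is a hypothetical name, not ColossalAI API):

from typing import Optional, Type

from torch.utils.data import DataLoader, Dataset, DistributedSampler


def make_dataloader(
    dataset: Dataset,
    batch_size: int,
    world_size: int,
    rank: int,
    shuffle: bool = False,
    distributed_sampler_cls: Optional[Type[DistributedSampler]] = None,
    **kwargs,
) -> DataLoader:
    # Fall back to the stock sampler unless the caller injects a subclass,
    # e.g. a stateful sampler that can resume mid-epoch.
    sampler_cls = distributed_sampler_cls or DistributedSampler
    sampler = sampler_cls(dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, **kwargs)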
@@ -456,7 +456,16 @@ class GeminiPlugin(DPPluginBase):
         return ["cuda", "npu"]
 
     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -484,7 +493,8 @@ class GeminiPlugin(DPPluginBase):
         extra_dp_world_size = self.pg_mesh.size(DP_AXIS)
         zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
         extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
-        sampler = DistributedSampler(
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(
             dataset,
             num_replicas=zero_world_size * extra_dp_world_size,
             rank=zero_rank * extra_dp_world_size + extra_dp_rank,
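In the Gemini plugin the sampler spans a 2-D data mesh (a ZeRO axis times an extra data-parallel axis): it sees zero_world_size * extra_dp_world_size replicas, and each rank's flat index is zero_rank * extra_dp_world_size + extra_dp_rank. A quick numeric check of that flattening (the sizes 4 and 2 are made up for the example):

# Every (zero_rank, extra_dp_rank) pair maps to a unique flat sampler rank.
zero_world_size, extra_dp_world_size = 4, 2
flat_ranks = {
    (zero_rank, extra_dp_rank): zero_rank * extra_dp_world_size + extra_dp_rank
    for zero_rank in range(zero_world_size)
    for extra_dp_rank in range(extra_dp_world_size)
}
num_replicas = zero_world_size * extra_dp_world_size
assert sorted(flat_ranks.values()) == list(range(num_replicas))
print(flat_ranks)  # {(0, 0): 0, (0, 1): 1, (1, 0): 2, ..., (3, 1): 7}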
@@ -1205,7 +1205,16 @@ class HybridParallelPlugin(PipelinePluginBase):
         return outputs
 
     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -1229,7 +1238,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
         """
         _kwargs = kwargs.copy()
-        sampler = DistributedSampler(
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(
             dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
         )
 
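With hybrid parallelism only the data-parallel axis of the process-group mesh shards the dataset, so ranks that differ only in their tensor- or pipeline-parallel coordinates load the same shard. The sketch below illustrates that mapping; the mesh shape and axis order are assumptions for the demo, not ColossalAI internals.

import itertools

# Assumed mesh: (pp, dp, tp) = (2, 2, 2) over 8 ranks, with the data-parallel axis at index 1.
pp, dp, tp = 2, 2, 2
DP_AXIS = 1

for global_rank, coord in enumerate(itertools.product(range(pp), range(dp), range(tp))):
    sampler_rank = coord[DP_AXIS]  # what prepare_dataloader would hand to the sampler
    num_replicas = dp              # the sampler sees only dp replicas, not all 8 ranks
    print(f"rank {global_rank}: mesh coord {coord} -> sampler shard {sampler_rank}/{num_replicas}")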