mirror of https://github.com/hpcaitech/ColossalAI
[llama] fix dataloader for hybrid parallel (#5358)
* [plugin] refactor prepare dataloader
* [plugin] update train script

pull/5362/head
parent 2dd01e3a14
commit 6c0fa7b9a8
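The refactor moves dataloader construction from the standalone `setup_distributed_dataloader` helper into the booster plugins' `prepare_dataloader`, which now accepts a `distributed_sampler_cls`. A minimal usage sketch under assumed settings (the parallel configuration, dataset path, batch size, and tokenizer id below are placeholders for illustration, not values from this commit):

```python
import colossalai
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import AutoTokenizer

from colossal_llama2.dataset.loader import (
    DataCollatorForSupervisedDataset,
    StatefulDistributedSampler,
    load_tokenized_dataset,
)

colossalai.launch_from_torch(config={})  # assumes a torchrun-style distributed launch
plugin = HybridParallelPlugin(tp_size=2, pp_size=1)  # illustrative parallel config

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")  # placeholder tokenizer
dataset = load_tokenized_dataset(dataset_paths=["/path/to/tokenized_dataset"], mode="train")
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=4096)

# Training scripts now ask the plugin for the dataloader instead of calling the
# removed setup_distributed_dataloader helper.
dataloader = plugin.prepare_dataloader(
    dataset=dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
    collate_fn=data_collator,
    distributed_sampler_cls=StatefulDistributedSampler,
)
```

Passing `distributed_sampler_cls` is optional; every plugin falls back to `torch.utils.data.DistributedSampler` when it is omitted.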
@@ -1,20 +1,16 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
-import numpy as np
 import os
-import random
 from dataclasses import dataclass
-from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
+from typing import Dict, Iterator, List, Optional, Sequence, Union
 
 import torch
-from datasets import dataset_dict, load_from_disk
-from datasets import Dataset as HFDataset
-from torch.distributed import ProcessGroup
-from torch.distributed.distributed_c10d import _get_default_group
-from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
-from transformers.tokenization_utils import PreTrainedTokenizer
 import torch.nn.functional as F
+from datasets import Dataset as HFDataset
+from datasets import dataset_dict, load_from_disk
+from torch.utils.data import ConcatDataset, Dataset, DistributedSampler
+from transformers.tokenization_utils import PreTrainedTokenizer
 
 DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
 PathType = Union[str, os.PathLike]
@@ -171,49 +167,3 @@ class StatefulDistributedSampler(DistributedSampler):
 
     def set_start_index(self, start_index: int) -> None:
         self.start_index = start_index
-
-
-def setup_distributed_dataloader(
-    dataset: DatasetType,
-    batch_size: int = 1,
-    shuffle: bool = False,
-    seed: int = 1024,
-    drop_last: bool = False,
-    pin_memory: bool = False,
-    num_workers: int = 0,
-    collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
-    process_group: Optional[ProcessGroup] = None,
-    **kwargs,
-) -> DataLoader:
-    """
-    Setup dataloader for distributed training.
-    """
-    _kwargs = kwargs.copy()
-    process_group = process_group or _get_default_group()
-    sampler = StatefulDistributedSampler(
-        dataset=dataset,
-        num_replicas=process_group.size(),
-        rank=process_group.rank(),
-        shuffle=shuffle,
-        seed=seed,
-        drop_last=drop_last,
-    )
-
-    # Deterministic dataloader
-    def seed_worker(worker_id: int) -> None:
-        worker_seed = seed
-        np.random.seed(worker_seed)
-        torch.manual_seed(worker_seed)
-        random.seed(worker_seed)
-
-    return DataLoader(
-        dataset=dataset,
-        batch_size=batch_size,
-        sampler=sampler,
-        num_workers=num_workers,
-        collate_fn=collate_fn,
-        pin_memory=pin_memory,
-        drop_last=drop_last,
-        worker_init_fn=seed_worker,
-        **_kwargs,
-    )
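For context, the retained `StatefulDistributedSampler` exposes `set_start_index` so that a run resumed mid-epoch can skip samples already consumed. A rough sketch of that idea (an illustrative stand-in, not the library's exact implementation):

```python
from torch.utils.data import DistributedSampler


class ResumableSampler(DistributedSampler):
    """Illustrative stand-in: a DistributedSampler that can skip consumed samples."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.start_index = 0  # number of samples already consumed in this epoch

    def __iter__(self):
        # Reuse DistributedSampler's shuffling and rank partitioning,
        # then drop the prefix that was consumed before the restart.
        indices = list(super().__iter__())
        return iter(indices[self.start_index:])

    def __len__(self) -> int:
        return self.num_samples - self.start_index

    def set_start_index(self, start_index: int) -> None:
        self.start_index = start_index
```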
@@ -16,7 +16,6 @@ from colossal_llama2.dataset.loader import (
     DataCollatorForSupervisedDataset,
     StatefulDistributedSampler,
     load_tokenized_dataset,
-    setup_distributed_dataloader,
 )
 from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
 from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
@@ -194,12 +193,13 @@ def main() -> None:
 
     dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
-    dataloader = setup_distributed_dataloader(
+    dataloader = plugin.prepare_dataloader(
         dataset=dataset,
         batch_size=args.micro_batch_size,
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
+        distributed_sampler_cls=StatefulDistributedSampler,
     )
     coordinator.print_on_master(
         f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
@@ -16,7 +16,6 @@ from colossal_llama2.dataset.loader import (
     DataCollatorForSupervisedDataset,
     StatefulDistributedSampler,
     load_tokenized_dataset,
-    setup_distributed_dataloader,
 )
 from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
 from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
@@ -203,12 +202,13 @@ def main() -> None:
 
     dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
-    dataloader = setup_distributed_dataloader(
+    dataloader = plugin.prepare_dataloader(
         dataset=dataset,
         batch_size=args.micro_batch_size,
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
+        distributed_sampler_cls=StatefulDistributedSampler,
     )
     coordinator.print_on_master(
         f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
@@ -21,7 +21,16 @@ class DPPluginBase(Plugin):
         self.world_size = dist.get_world_size()
 
     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -45,7 +54,8 @@ class DPPluginBase(Plugin):
             :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
         """
         _kwargs = kwargs.copy()
-        sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
 
         # Deterministic dataloader
         def seed_worker(worker_id):
@@ -456,7 +456,16 @@ class GeminiPlugin(DPPluginBase):
         return ["cuda", "npu"]
 
     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -484,7 +493,8 @@ class GeminiPlugin(DPPluginBase):
         extra_dp_world_size = self.pg_mesh.size(DP_AXIS)
         zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
         extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
-        sampler = DistributedSampler(
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(
             dataset,
             num_replicas=zero_world_size * extra_dp_world_size,
             rank=zero_rank * extra_dp_world_size + extra_dp_rank,
@@ -1205,7 +1205,16 @@ class HybridParallelPlugin(PipelinePluginBase):
         return outputs
 
     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -1229,7 +1238,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
         """
         _kwargs = kwargs.copy()
-        sampler = DistributedSampler(
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(
             dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
         )
 
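The plugins only require that the sampler class be constructible as `cls(dataset, num_replicas=..., rank=..., shuffle=...)`, so any `DistributedSampler` subclass with a compatible constructor can be passed. A hedged sketch of how a training script might use the stateful sampler when resuming (the `checkpoint` dict and its `sampler_start_idx` key are hypothetical names for illustration, not part of this commit):

```python
# Reach the sampler through the dataloader returned by plugin.prepare_dataloader.
sampler = dataloader.sampler
if isinstance(sampler, StatefulDistributedSampler) and checkpoint is not None:
    # Skip the samples that were already consumed before the interruption.
    sampler.set_start_index(checkpoint.get("sampler_start_idx", 0))

for step, batch in enumerate(dataloader):
    ...  # one training step
```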