replace the customized dataloader setup with the built-in one

pull/5759/head
YeAnbang 2024-06-07 09:43:42 +00:00
parent 790e1362a6
commit 0d7ff10ea5
12 changed files with 79 additions and 218 deletions
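In short, each training script now obtains its dataloader from the booster plugin's built-in prepare_dataloader method, passing StatefulDistributedSampler as the sampler class, instead of the bespoke setup_distributed_dataloader helper. A minimal sketch of the new call pattern (plugin, train_dataset, and data_collator stand in for objects each script builds earlier):

# Sketch only: `plugin` is a booster plugin such as HybridParallelPlugin;
# `train_dataset` and `data_collator` are placeholders for objects
# constructed earlier in the script.
train_dataloader = plugin.prepare_dataloader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
    collate_fn=data_collator,
    distributed_sampler_cls=StatefulDistributedSampler,
)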


@@ -5,7 +5,6 @@ from .loader (
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
load_tokenized_dataset,
setup_distributed_dataloader,
)
from .tokenization_utils import supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf
@@ -17,7 +16,6 @@ __all__ = [
"DataCollatorForSupervisedDataset",
"StatefulDistributedSampler",
"load_tokenized_dataset",
"setup_distributed_dataloader",
"supervised_tokenize_pretrain",
"supervised_tokenize_sft",
"tokenize_rlhf",


@@ -4,22 +4,16 @@
Dataloader for sft, dpo, ppo
"""
import math
import os
import random
from dataclasses import dataclass
from typing import Callable, Dict, Iterator, List, Optional, Sequence, Union
from typing import Dict, Iterator, List, Optional, Sequence, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from coati.dataset.utils import chuncate_sequence, pad_to_max_len
from datasets import Dataset as HFDataset
from datasets import dataset_dict, load_from_disk
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import ConcatDataset, DataLoader, Dataset, DistributedSampler
from torch.utils.data import ConcatDataset, Dataset, DistributedSampler
from transformers.tokenization_utils import PreTrainedTokenizer
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
@@ -236,159 +230,26 @@ class DataCollatorForPreferenceDataset(object):
class StatefulDistributedSampler(DistributedSampler):
"""
Stateful distributed sampler for multi-stage training.
"""
def __init__(
self,
dataset: DatasetType,
dataset: Dataset,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
tp_size: int = 1,
sp_size: int = 1,
pp_size: int = 1,
) -> None:
if not tp_size > 1:
super().__init__(
dataset=dataset,
num_replicas=num_replicas,
rank=rank,
shuffle=shuffle,
seed=seed,
drop_last=drop_last,
)
else:
# adapted from https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/torch/utils/data/distributed.py#L62
if rank is None:
rank = dist.get_rank()
dist.get_world_size()
# dp_size = world_size // (tp_size * sp_size * pp_size)
dp_rank = int(rank / (tp_size * sp_size * pp_size)) # data parallel rank:
if rank < 0:
raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, 0]")
self.dataset = dataset
self.num_replicas = num_replicas
self.dp_rank = dp_rank
self.rank = rank
self.epoch = 0
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed
self.start_index = 0
self.tp_size = tp_size
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
self.start_index: int = 0
def __iter__(self) -> Iterator:
if self.tp_size > 1:
# TODO Add support for tp_group not equal to 1
pass
# adapted from https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/torch/utils/data/distributed.py#L96
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[
self.dp_rank : self.dp_rank + self.total_size : self.num_replicas
] # num_replicas=tp_group=1, we only support tp_group==1 for now
assert len(indices) == self.num_samples
return iter(indices)
else:
iterator = super().__iter__()
indices = list(iterator)
indices = indices[self.start_index :]
return iter(indices)
iterator = super().__iter__()
indices = list(iterator)
indices = indices[self.start_index :]
return iter(indices)
def __len__(self) -> int:
return self.num_samples - self.start_index
def set_start_index(self, start_index: int) -> None:
self.start_index = start_index
def setup_distributed_dataloader(
dataset: DatasetType,
batch_size: int = 1,
shuffle: bool = False,
seed: int = 1024,
drop_last: bool = False,
pin_memory: bool = False,
num_workers: int = 0,
collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
process_group: Optional[ProcessGroup] = None,
tp_size: Optional[int] = 1,
sp_size: Optional[int] = 1,
pp_size: Optional[int] = 1,
**kwargs,
) -> DataLoader:
"""
Setup dataloader for distributed training.
"""
_kwargs = kwargs.copy()
process_group = process_group or _get_default_group()
# world_size = tp_size * pp_size
assert (
process_group.size() % tp_size == 0
), f"process_group.size()={process_group.size()} must be divisible by tp_size={tp_size}"
sampler = StatefulDistributedSampler(
dataset=dataset,
num_replicas=int(process_group.size() / tp_size),
rank=process_group.rank(),
shuffle=shuffle,
seed=seed,
drop_last=drop_last,
tp_size=tp_size,
sp_size=sp_size,
pp_size=pp_size,
)
# Deterministic dataloader
def seed_worker(worker_id: int) -> None:
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset=dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=drop_last,
worker_init_fn=seed_worker,
**_kwargs,
)
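For reference, a minimal sketch of how the retained sampler resumes mid-epoch through set_start_index, under hypothetical single-process values (num_replicas=1, rank=0, no shuffling):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))
sampler = StatefulDistributedSampler(dataset, num_replicas=1, rank=0, shuffle=False)

# Suppose a checkpoint recorded that 4 samples were already consumed this epoch.
sampler.set_start_index(4)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)
# Iteration now resumes at index 4 and yields samples 4..9 only.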


@@ -55,8 +55,6 @@ def supervised_tokenize_sft(
for mess in messages:
from_str = mess["from"]
if from_str is None:
print(mess)
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
@@ -133,24 +131,20 @@ def supervised_tokenize_sft(
labels[-1] = tokenizer.eos_token_id
# For some model without bos/eos may raise the following errors
try:
inputs_decode = tokenizer.decode(tokenized)
start = 0
end = 0
label_decode = []
for i in range(len(labels)):
if labels[i] == ignore_index:
if start != end:
label_decode.append(tokenizer.decode(labels[start + 1 : i], skip_special_tokens=False))
start = i
end = i
else:
end = i
if i == len(labels) - 1:
label_decode.append(tokenizer.decode(labels[start + 1 :], skip_special_tokens=False))
except TypeError as e:
raise TypeError(str(e) + f"\nUnable to decode input_ids: {tokenized}")
inputs_decode = tokenizer.decode(tokenized)
start = 0
end = 0
label_decode = []
for i in range(len(labels)):
if labels[i] == ignore_index:
if start != end:
label_decode.append(tokenizer.decode(labels[start + 1 : i], skip_special_tokens=False))
start = i
end = i
else:
end = i
if i == len(labels) - 1:
label_decode.append(tokenizer.decode(labels[start + 1 :], skip_special_tokens=False))
# Check if all labels are ignored, this may happen when the tokenized length is too long
if labels.count(ignore_index) == len(labels):
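To illustrate the check above: a sample whose labels are all ignore_index (for example, when the prompt alone fills the maximum length) carries no supervision signal and is filtered out. A tiny self-contained example, assuming the conventional ignore value of -100:

IGNORE_INDEX = -100  # assumption: the ignore_index used for masked positions

labels = [-100, -100, -100]  # every position masked -> sample dropped
assert labels.count(IGNORE_INDEX) == len(labels)

labels = [-100, -100, 42, 13]  # at least one supervised token -> sample kept
assert labels.count(IGNORE_INDEX) != len(labels)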


@@ -2,7 +2,11 @@
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
2
31007,
326,
30962,
437,
31007
],
"end_of_assistant": "<|im_end|>"
}
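The new stop_ids presumably encode the end_of_assistant marker "<|im_end|>" under the paired model's tokenizer rather than the bare eos id 2. A hedged sketch of how such ids can be derived (the checkpoint name is an assumption; substitute the model this template targets):

from transformers import AutoTokenizer

# Assumption: a ChatGLM-style checkpoint; replace with the actual model.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
print(tokenizer.encode("<|im_end|>", add_special_tokens=False))
# The printed ids should correspond to the "stop_ids" list above.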


@@ -29,6 +29,7 @@
- [Alternative Option For RLHF: Direct Preference Optimization](#alternative-option-for-rlhf-direct-preference-optimization)
- [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instructs-tuning)
- [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training)
- [List of Supported Models](#list-of-supported-models)
- [Hardware Requirements](#hardware-requirements)
- [Inference example](#inference-example)
- [Attention](#attention)
@@ -744,6 +745,26 @@ For DPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption o
- 4 H800 GPUs
- zero2, batch size=4, VRAM Usage=67544.47 MB
## List of Supported Models
For SFT, we support the following models/series:
- Colossal-LLaMA-2
- ChatGLM2
- ChatGLM3 (only with zero2, zero2_cpu plugin)
- Baichuan2
- LLaMA2
- Qwen1.5-7B-Chat (with transformers==4.39.1)
- Yi-1.5
For PPO and DPO, we theoretically support the following models/series (without guarantee):
- Colossal-LLaMA-2 (tested)
- ChatGLM2
- Baichuan2
- LLaMA2 (tested)
- Qwen1.5-7B-Chat (with transformers==4.39.1)
- Yi-1.5
- The zero2 and zero2_cpu plugins also support a wide range of chat models not listed above.
## Inference example


@@ -1 +1,5 @@
172.27.183.199
XXX.XX.XXX.XXX # Your master IP
XXX.XX.XXX.XXX # Your slave IPs
XXX.XX.XXX.XXX # Your slave IPs
XXX.XX.XXX.XXX # Your slave IPs
XXX.XX.XXX.XXX # Your slave IPs


@@ -5,12 +5,7 @@ import resource
from contextlib import nullcontext
import torch
from coati.dataset import (
DataCollatorForPreferenceDataset,
StatefulDistributedSampler,
load_tokenized_dataset,
setup_distributed_dataloader,
)
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.models import convert_to_lora_module, disable_dropout
from coati.trainer import DPOTrainer
from coati.utils import load_checkpoint
@@ -174,15 +169,14 @@ def train(args):
mode_map = {"train": "train", "valid": "validation", "test": "test"}
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
train_dataloader = setup_distributed_dataloader(
train_dataloader = plugin.prepare_dataloader(
dataset=train_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
tp_size=plugin.tp_size if hasattr(plugin, "tp_size") else 1,
sp_size=plugin.sp_size if hasattr(plugin, "sp_size") else 1,
pp_size=plugin.pp_size if hasattr(plugin, "pp_size") else 1,
distributed_sampler_cls=StatefulDistributedSampler,
)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps


@@ -12,7 +12,6 @@ from coati.dataset import (
StatefulDistributedSampler,
load_tokenized_dataset,
setup_conversation_template,
setup_distributed_dataloader,
)
from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout
from coati.trainer import PPOTrainer
@@ -209,6 +208,9 @@ def train(args):
max_norm=args.grad_clip,
)
elif args.plugin == "3d":
if args.use_flash_attn and (args.tp > 1 or args.pp > 1 or args.sp > 1 or args.enable_sequence_parallelism):
logger.warning("Flash attention cannot be used with 3D parallelism for PPO training. Disabling it.")
args.use_flash_attn = False
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
@@ -247,29 +249,26 @@
mode_map = {"train": "train", "valid": "validation", "test": "test"}
train_prompt_dataset = load_tokenized_dataset(dataset_paths=args.prompt_dataset, mode="train", mode_map=mode_map)
data_collator = DataCollatorForPromptDataset(tokenizer=tokenizer, max_length=args.max_length - args.max_seq_len)
train_prompt_dataloader = setup_distributed_dataloader(
train_prompt_dataloader = plugin.prepare_dataloader(
dataset=train_prompt_dataset,
batch_size=args.experience_batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
tp_size=plugin.tp_size if hasattr(plugin, "tp_size") else 1,
sp_size=plugin.sp_size if hasattr(plugin, "sp_size") else 1,
pp_size=plugin.pp_size if hasattr(plugin, "pp_size") else 1,
distributed_sampler_cls=StatefulDistributedSampler,
)
if len(args.ptx_dataset) > 0:
train_ptx_dataset = load_tokenized_dataset(dataset_paths=args.ptx_dataset, mode="train", mode_map=mode_map)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
train_pretrain_dataloader = setup_distributed_dataloader(
train_pretrain_dataloader = plugin.prepare_dataloader(
dataset=train_ptx_dataset,
batch_size=args.ptx_batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
tp_size=plugin.tp_size if hasattr(plugin, "tp_size") else 1,
sp_size=plugin.sp_size if hasattr(plugin, "sp_size") else 1,
pp_size=plugin.pp_size if hasattr(plugin, "pp_size") else 1,
distributed_sampler_cls=StatefulDistributedSampler,
)
else:
train_pretrain_dataloader = None


@@ -6,12 +6,7 @@ import resource
from contextlib import nullcontext
import torch
from coati.dataset import (
DataCollatorForPreferenceDataset,
StatefulDistributedSampler,
load_tokenized_dataset,
setup_distributed_dataloader,
)
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module
from coati.trainer import RewardModelTrainer
from coati.utils import load_checkpoint
@@ -169,17 +164,15 @@ def train(args):
mode_map = {"train": "train", "valid": "validation", "test": "test"}
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
train_dataloader = setup_distributed_dataloader(
train_dataloader = plugin.prepare_dataloader(
dataset=train_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
tp_size=plugin.tp_size if hasattr(plugin, "tp_size") else 1,
sp_size=plugin.sp_size if hasattr(plugin, "sp_size") else 1,
pp_size=plugin.pp_size if hasattr(plugin, "pp_size") else 1,
distributed_sampler_cls=StatefulDistributedSampler,
)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
math.ceil(args.max_epochs * num_update_steps_per_epoch)


@@ -8,7 +8,7 @@ import sys
from contextlib import nullcontext
import torch
from coati.dataset import DataCollatorForSupervisedDataset, load_tokenized_dataset, setup_distributed_dataloader
from coati.dataset import DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.models import convert_to_lora_module
from coati.trainer import SFTTrainer
from coati.utils import load_checkpoint
@@ -189,21 +189,15 @@ def train(args):
)
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_len)
train_dataloader = setup_distributed_dataloader(
train_dataloader = plugin.prepare_dataloader(
dataset=dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
tp_size=plugin.tp_size if hasattr(plugin, "tp_size") else 1,
sp_size=plugin.sp_size if hasattr(plugin, "sp_size") else 1,
pp_size=plugin.pp_size if hasattr(plugin, "pp_size") else 1,
distributed_sampler_cls=StatefulDistributedSampler,
)
# print(len(train_dataloader))
# for batch in train_dataloader:
# print(dist.get_rank(), tokenizer.batch_decode(batch["input_ids"]))
# break
coordinator.print_on_master(
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)


@@ -6,8 +6,8 @@ TEST_DATA_DIR=$BASE_DIR/tests/test_data
DATA_SAVE_PATH=$BASE_TEMP_DIR/tests
CONFIG_DIR=$BASE_DIR/config
# MODELS=("colossal-llama2" "llama2" "mistral" "chatGLM2" "chatGLM3" "deepseek" "Yi" "baichuan")
MODELS=("colossal-llama2" "llama2" "chatGLM2" "chatGLM3" "deepseek" "Yi")
MODELS=("colossal-llama2" "llama2" "mistral" "chatGLM2" "chatGLM3" "deepseek" "Yi" "baichuan") # for local test
# MODELS=("colossal-llama2" "llama2" "chatGLM2" "chatGLM3" "deepseek" "Yi")
get_pretrain() {
local model=$1


@@ -30,8 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
MODELS_DIR=$TEMP_DIR/models_config
# Skip those tests due to CI tests timeout
MODELS=('llama')
# ADVANCED_PLUGINS=('pp' 'tp_zero2' 'tp_pp' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' '3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu')
ADVANCED_PLUGINS=('tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu')
ADVANCED_PLUGINS=('sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu') # pp is still buggy
PLUGINS=('3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu')
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
@@ -281,7 +280,7 @@ echo "[Test]: testing ppo ..."
SKIPPED_TESTS=(
# llama-3d # 3d plugin doesn't support lora
llama-3d # 3d plugin doesn't support lora
llama-gemini # gemini doesn't support lora
)
@@ -359,7 +358,7 @@ for lora_rank in ${LORA_RANK[@]}; do
$grad_ckpt \
--max_len 400 \
--max_seq_len 10 \
--use_flash_attn
# --use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf $MODEL_SAVE_PATH/*