mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #5759 from hpcaitech/colossalchat_upgrade
[ColossalChat] Colossalchat upgradepull/5793/head
commit
74f4a29734
|
@ -4,10 +4,11 @@ on:
|
|||
pull_request:
|
||||
types: [synchronize, opened, reopened]
|
||||
paths:
|
||||
- "applications/Chat/coati/**"
|
||||
- "applications/Chat/requirements.txt"
|
||||
- "applications/Chat/setup.py"
|
||||
- "applications/Chat/examples/**"
|
||||
- "applications/ColossalChat/coati/**"
|
||||
- "applications/ColossalChat/requirements.txt"
|
||||
- "applications/ColossalChat/setup.py"
|
||||
- "applications/ColossalChat/examples/**"
|
||||
- "applications/ColossalChat/tests/**"
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
|
@ -41,7 +42,7 @@ jobs:
|
|||
|
||||
- name: Install Transformers
|
||||
run: |
|
||||
pip install transformers==4.34.1
|
||||
pip install transformers==4.36.2
|
||||
|
||||
- name: Execute Examples
|
||||
run: |
|
||||
|
|
|
@ -4,12 +4,11 @@ on:
|
|||
pull_request:
|
||||
types: [synchronize, opened, reopened]
|
||||
paths:
|
||||
- 'applications/Chat/coati/**'
|
||||
- 'applications/Chat/requirements.txt'
|
||||
- 'applications/Chat/setup.py'
|
||||
- 'applications/Chat/requirements-test.txt'
|
||||
- 'applications/Chat/tests/**'
|
||||
- 'applications/Chat/pytest.ini'
|
||||
- 'applications/ColossalChat/coati/**'
|
||||
- 'applications/ColossalChat/requirements.txt'
|
||||
- 'applications/ColossalChat/setup.py'
|
||||
- 'applications/ColossalChat/tests/**'
|
||||
- 'applications/ColossalChat/pytest.ini'
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
|
|
|
@ -5,7 +5,6 @@ from .loader import (
|
|||
DataCollatorForSupervisedDataset,
|
||||
StatefulDistributedSampler,
|
||||
load_tokenized_dataset,
|
||||
setup_distributed_dataloader,
|
||||
)
|
||||
from .tokenization_utils import supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf
|
||||
|
||||
|
@ -17,7 +16,6 @@ __all__ = [
|
|||
"DataCollatorForSupervisedDataset",
|
||||
"StatefulDistributedSampler",
|
||||
"load_tokenized_dataset",
|
||||
"setup_distributed_dataloader",
|
||||
"supervised_tokenize_pretrain",
|
||||
"supervised_tokenize_sft",
|
||||
"tokenize_rlhf",
|
||||
|
|
|
@ -17,6 +17,7 @@ class Conversation:
|
|||
system_message: str
|
||||
chat_template: str
|
||||
stop_ids: List[int]
|
||||
end_of_assistant: str
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):
|
||||
|
@ -24,7 +25,9 @@ class Conversation:
|
|||
Setup the conversation template from config
|
||||
"""
|
||||
tokenizer.chat_template = config["chat_template"]
|
||||
conv = cls(tokenizer, config["system_message"], config["chat_template"], config["stop_ids"])
|
||||
conv = cls(
|
||||
tokenizer, config["system_message"], config["chat_template"], config["stop_ids"], config["end_of_assistant"]
|
||||
)
|
||||
conv.clear()
|
||||
return conv
|
||||
|
||||
|
@ -109,6 +112,8 @@ def setup_conversation_template(
|
|||
"""
|
||||
if any([s not in chat_template_config.keys() for s in Conversation.get_conversation_template_keys()]):
|
||||
# Try to automatically set up conversation template, if fail, it throws an error that you need to do it manually
|
||||
if "end_of_assistant" not in chat_template_config:
|
||||
raise ValueError("Please set the end of assistant token.")
|
||||
if "system_message" not in chat_template_config:
|
||||
logger.warning("No system message is provided, will not use system message.")
|
||||
if "chat_template" not in chat_template_config:
|
||||
|
|
|
@ -4,22 +4,16 @@
|
|||
Dataloader for sft, dpo, ppo
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, Iterator, List, Optional, Sequence, Union
|
||||
from typing import Dict, Iterator, List, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from coati.dataset.utils import chuncate_sequence, pad_to_max_len
|
||||
from datasets import Dataset as HFDataset
|
||||
from datasets import dataset_dict, load_from_disk
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, DistributedSampler
|
||||
from torch.utils.data import ConcatDataset, Dataset, DistributedSampler
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
|
||||
|
@ -236,148 +230,26 @@ class DataCollatorForPreferenceDataset(object):
|
|||
|
||||
|
||||
class StatefulDistributedSampler(DistributedSampler):
|
||||
"""
|
||||
Stateful distributed sampler for multi-stage training.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: DatasetType,
|
||||
dataset: Dataset,
|
||||
num_replicas: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
shuffle: bool = True,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
use_tp: Optional[bool] = False,
|
||||
) -> None:
|
||||
if not use_tp:
|
||||
super().__init__(
|
||||
dataset=dataset,
|
||||
num_replicas=num_replicas,
|
||||
rank=rank,
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
drop_last=drop_last,
|
||||
)
|
||||
else:
|
||||
# adapted from https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/torch/utils/data/distributed.py#L62
|
||||
# TODO: support tp_group>1. will fix it later
|
||||
num_replicas = 1
|
||||
if rank is None:
|
||||
rank = dist.get_rank()
|
||||
if rank < 0:
|
||||
raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, 0]")
|
||||
self.dataset = dataset
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.epoch = 0
|
||||
self.drop_last = drop_last
|
||||
# If the dataset length is evenly divisible by # of replicas, then there
|
||||
# is no need to drop any data, since the dataset will be split equally.
|
||||
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
|
||||
# Split to nearest available length that is evenly divisible.
|
||||
# This is to ensure each rank receives the same amount of data when
|
||||
# using this Sampler.
|
||||
self.num_samples = math.ceil(
|
||||
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
self.start_index = 0
|
||||
self.use_tp = use_tp
|
||||
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
|
||||
self.start_index: int = 0
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
if self.use_tp:
|
||||
# TODO Add support for tp_group not equal to 1
|
||||
pass
|
||||
# adpated from https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/torch/utils/data/distributed.py#L96
|
||||
if self.shuffle:
|
||||
# deterministically shuffle based on epoch and seed
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.seed + self.epoch)
|
||||
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
|
||||
else:
|
||||
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
|
||||
|
||||
if not self.drop_last:
|
||||
# add extra samples to make it evenly divisible
|
||||
padding_size = self.total_size - len(indices)
|
||||
if padding_size <= len(indices):
|
||||
indices += indices[:padding_size]
|
||||
else:
|
||||
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
|
||||
else:
|
||||
# remove tail of data to make it evenly divisible.
|
||||
indices = indices[: self.total_size]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[
|
||||
: self.total_size : self.num_replicas
|
||||
] # num_replicas=tp_group=1, we only support tp_group==1 for now
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
return iter(indices)
|
||||
|
||||
else:
|
||||
iterator = super().__iter__()
|
||||
indices = list(iterator)
|
||||
indices = indices[self.start_index :]
|
||||
return iter(indices)
|
||||
iterator = super().__iter__()
|
||||
indices = list(iterator)
|
||||
indices = indices[self.start_index :]
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_samples - self.start_index
|
||||
|
||||
def set_start_index(self, start_index: int) -> None:
|
||||
self.start_index = start_index
|
||||
|
||||
|
||||
def setup_distributed_dataloader(
|
||||
dataset: DatasetType,
|
||||
batch_size: int = 1,
|
||||
shuffle: bool = False,
|
||||
seed: int = 1024,
|
||||
drop_last: bool = False,
|
||||
pin_memory: bool = False,
|
||||
num_workers: int = 0,
|
||||
collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
use_tp: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Setup dataloader for distributed training.
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
process_group = process_group or _get_default_group()
|
||||
sampler = StatefulDistributedSampler(
|
||||
dataset=dataset,
|
||||
num_replicas=process_group.size() if not use_tp else 1,
|
||||
rank=process_group.rank(),
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
drop_last=drop_last,
|
||||
use_tp=use_tp,
|
||||
)
|
||||
|
||||
# Deterministic dataloader
|
||||
def seed_worker(worker_id: int) -> None:
|
||||
worker_seed = seed
|
||||
np.random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
return DataLoader(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=pin_memory,
|
||||
drop_last=drop_last,
|
||||
worker_init_fn=seed_worker,
|
||||
**_kwargs,
|
||||
)
|
||||
|
|
|
@ -95,17 +95,27 @@ def supervised_tokenize_sft(
|
|||
|
||||
target_turn = turns[target_turn_index - 1]
|
||||
prompt = template.get_prompt(2 * target_turn)
|
||||
chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt)
|
||||
chunks, require_loss = split_templated_prompt_into_chunks(
|
||||
template.messages[: 2 * target_turn], prompt, conversation_template.end_of_assistant
|
||||
)
|
||||
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
|
||||
|
||||
labels = [ignore_index] * len(tokenized)
|
||||
label_decode = []
|
||||
for start, end in zip(starts, ends):
|
||||
if end == len(tokenized):
|
||||
tokenized = tokenized + [tokenizer.eos_token_id]
|
||||
labels = labels + [ignore_index]
|
||||
labels[start : end + 1] = tokenized[start : end + 1]
|
||||
label_decode.append(tokenizer.decode(tokenized[start : end + 1], skip_special_tokens=False))
|
||||
labels[start:end] = tokenized[start:end]
|
||||
|
||||
# truncate the sequence at the last token that requires loss calculation
|
||||
to_truncate_len = 0
|
||||
for i in range(len(tokenized) - 1, -1, -1):
|
||||
if labels[i] == ignore_index:
|
||||
to_truncate_len += 1
|
||||
else:
|
||||
break
|
||||
tokenized = tokenized[: len(tokenized) - to_truncate_len]
|
||||
labels = labels[: len(labels) - to_truncate_len]
|
||||
|
||||
if tokenizer.bos_token_id is not None:
|
||||
if tokenized[0] != tokenizer.bos_token_id:
|
||||
|
@ -121,10 +131,20 @@ def supervised_tokenize_sft(
|
|||
labels[-1] = tokenizer.eos_token_id
|
||||
|
||||
# For some model without bos/eos may raise the following errors
|
||||
try:
|
||||
inputs_decode = tokenizer.decode(tokenized)
|
||||
except TypeError as e:
|
||||
raise TypeError(str(e) + f"\nUnable to decode input_ids: {tokenized}")
|
||||
inputs_decode = tokenizer.decode(tokenized)
|
||||
start = 0
|
||||
end = 0
|
||||
label_decode = []
|
||||
for i in range(len(labels)):
|
||||
if labels[i] == ignore_index:
|
||||
if start != end:
|
||||
label_decode.append(tokenizer.decode(labels[start + 1 : i], skip_special_tokens=False))
|
||||
start = i
|
||||
end = i
|
||||
else:
|
||||
end = i
|
||||
if i == len(labels) - 1:
|
||||
label_decode.append(tokenizer.decode(labels[start + 1 :], skip_special_tokens=False))
|
||||
|
||||
# Check if all labels are ignored, this may happen when the tokenized length is too long
|
||||
if labels.count(ignore_index) == len(labels):
|
||||
|
@ -191,7 +211,10 @@ def tokenize_prompt_dataset(
|
|||
|
||||
# Prepare data
|
||||
prompt = template.get_prompt(target_turn, add_generation_prompt=True)
|
||||
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
|
||||
chunks, require_loss = split_templated_prompt_into_chunks(
|
||||
template.messages[:target_turn], prompt, conversation_template.end_of_assistant
|
||||
)
|
||||
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
|
||||
if tokenizer.bos_token_id is not None:
|
||||
if tokenized[0] != tokenizer.bos_token_id:
|
||||
tokenized = [tokenizer.bos_token_id] + tokenized
|
||||
|
@ -219,7 +242,9 @@ def apply_rlhf_data_format(
|
|||
):
|
||||
target_turn = int(len(template.messages) / 2)
|
||||
prompt = template.get_prompt(target_turn * 2)
|
||||
chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt)
|
||||
chunks, require_loss = split_templated_prompt_into_chunks(
|
||||
template.messages[: 2 * target_turn], prompt, template.end_of_assistant
|
||||
)
|
||||
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
|
||||
loss_mask = [0] * len(tokenized)
|
||||
mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id
|
||||
|
@ -232,8 +257,8 @@ def apply_rlhf_data_format(
|
|||
if end == len(tokenized):
|
||||
tokenized = tokenized + [tokenizer.eos_token_id]
|
||||
loss_mask = loss_mask + [1]
|
||||
loss_mask[start : end + 1] = [1] * len(loss_mask[start : end + 1])
|
||||
label_decode.append(tokenizer.decode(tokenized[start : end + 1], skip_special_tokens=False))
|
||||
loss_mask[start:end] = [1] * len(loss_mask[start:end])
|
||||
label_decode.append(tokenizer.decode(tokenized[start:end], skip_special_tokens=False))
|
||||
if tokenizer.bos_token_id is not None:
|
||||
if tokenized[0] != tokenizer.bos_token_id:
|
||||
tokenized = [tokenizer.bos_token_id] + tokenized
|
||||
|
|
|
@ -113,20 +113,25 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re
|
|||
return input_ids, loss_starts, loss_ends
|
||||
|
||||
|
||||
def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: str):
|
||||
def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: str, end_of_assistant: str):
|
||||
# Seperate templated prompt into chunks by human/assistant's lines, prepare data for tokenize_and_concatenate
|
||||
start_idx = 0
|
||||
chunks = []
|
||||
require_loss = []
|
||||
for line in messages:
|
||||
content_length = len(line["content"])
|
||||
first_occur = prompt.find(line["content"], start_idx)
|
||||
if line["role"].lower() == "assistant" and end_of_assistant in prompt[first_occur + content_length :]:
|
||||
content_length = (
|
||||
prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur
|
||||
)
|
||||
if prompt[first_occur - 1] != " ":
|
||||
chunks.append(prompt[start_idx:first_occur])
|
||||
chunks.append(prompt[first_occur : first_occur + len(line["content"])])
|
||||
chunks.append(prompt[first_occur : first_occur + content_length])
|
||||
else:
|
||||
chunks.append(prompt[start_idx : first_occur - 1])
|
||||
chunks.append(prompt[first_occur - 1 : first_occur + len(line["content"])])
|
||||
start_idx = first_occur + len(line["content"])
|
||||
chunks.append(prompt[first_occur - 1 : first_occur + content_length])
|
||||
start_idx = first_occur + content_length
|
||||
if line["role"].lower() == "assistant":
|
||||
require_loss.append(False)
|
||||
require_loss.append(True)
|
||||
|
|
|
@ -32,3 +32,9 @@ class Critic(BaseModel):
|
|||
)
|
||||
values = self.value_head(sequence_hidden_states).squeeze(-1) # ensure shape is (B, sequence length)
|
||||
return values
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.model.get_output_embeddings()
|
||||
|
|
|
@ -36,3 +36,9 @@ class RewardModel(BaseModel):
|
|||
)
|
||||
values = self.value_head(sequence_hidden_states).squeeze(-1) # Ensure shape is (B,)
|
||||
return values
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.model.get_output_embeddings()
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
{
|
||||
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}",
|
||||
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
"stop_ids": [
|
||||
7
|
||||
],
|
||||
"end_of_assistant": "<|im_end|>"
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
"stop_ids": [
|
||||
151645,
|
||||
151643
|
||||
],
|
||||
"end_of_assistant": "<|im_end|>"
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
{
|
||||
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
"stop_ids": [
|
||||
31007,
|
||||
326,
|
||||
30962,
|
||||
437,
|
||||
31007
|
||||
],
|
||||
"end_of_assistant": "<|im_end|>"
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
{
|
||||
"chat_template": "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
|
||||
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
"stop_ids": [
|
||||
2
|
||||
],
|
||||
"end_of_assistant": "<|user|>"
|
||||
}
|
|
@ -1,7 +0,0 @@
|
|||
{
|
||||
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\\n\\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don\\'t know the answer to a question, please don\\'t share false information.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
|
||||
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
"stop_ids": [
|
||||
2
|
||||
]
|
||||
}
|
|
@ -1,7 +0,0 @@
|
|||
{
|
||||
"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
"stop_ids": [
|
||||
2
|
||||
]
|
||||
}
|
|
@ -3,5 +3,6 @@
|
|||
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
"stop_ids": [
|
||||
2
|
||||
]
|
||||
],
|
||||
"end_of_assistant": "<|im_end|>"
|
||||
}
|
|
@ -3,5 +3,6 @@
|
|||
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
"stop_ids": [
|
||||
2
|
||||
]
|
||||
],
|
||||
"end_of_assistant": "</s>"
|
||||
}
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
{
|
||||
"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
|
||||
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
"stop_ids": [
|
||||
100001
|
||||
],
|
||||
"end_of_assistant": "<|end▁of▁sentence|>"
|
||||
}
|
|
@ -3,5 +3,6 @@
|
|||
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
"stop_ids": [
|
||||
2
|
||||
]
|
||||
],
|
||||
"end_of_assistant": "</s>"
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
"stop_ids": [
|
||||
null
|
||||
]
|
||||
50256
|
||||
],
|
||||
"end_of_assistant": "<|im_end|>"
|
||||
}
|
|
@ -3,5 +3,6 @@
|
|||
"system_message": null,
|
||||
"stop_ids": [
|
||||
2
|
||||
]
|
||||
],
|
||||
"end_of_assistant": "</s>"
|
||||
}
|
|
@ -1,7 +0,0 @@
|
|||
{
|
||||
"chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
|
||||
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
"stop_ids": [
|
||||
2
|
||||
]
|
||||
}
|
|
@ -1,7 +1,9 @@
|
|||
# Examples
|
||||
|
||||
|
||||
## Table of Contents
|
||||
|
||||
|
||||
- [Examples](#examples)
|
||||
- [Table of Contents](#table-of-contents)
|
||||
- [Install Requirements](#install-requirements)
|
||||
|
@ -27,28 +29,36 @@
|
|||
- [Alternative Option For RLHF: Direct Preference Optimization](#alternative-option-for-rlhf-direct-preference-optimization)
|
||||
- [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instructs-tuning)
|
||||
- [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training)
|
||||
- [List of Supported Models](#list-of-supported-models)
|
||||
- [Hardware Requirements](#hardware-requirements)
|
||||
- [Inference example](#inference-example)
|
||||
- [Attention](#attention)
|
||||
|
||||
|
||||
---
|
||||
|
||||
|
||||
## Install requirements
|
||||
|
||||
|
||||
```shell
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
## Get Start with ColossalRun
|
||||
|
||||
You can use colossalai run to launch multi-nodes training:
|
||||
|
||||
You can use colossalai run to launch multi-node training:
|
||||
```
|
||||
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
|
||||
train.py --OTHER_CONFIGURATIONS
|
||||
```
|
||||
Here is a sample hostfile:
|
||||
|
||||
|
||||
```
|
||||
hostname1
|
||||
hostname2
|
||||
|
@ -56,21 +66,29 @@ hostname3
|
|||
hostname4
|
||||
```
|
||||
|
||||
Make sure master node can access all nodes (including itself) by ssh without password. Here are some other arguments.
|
||||
|
||||
Make sure the master node can access all nodes (including itself) by ssh without a password. Here are some other arguments.
|
||||
|
||||
|
||||
- nnodes: number of nodes used in the training
|
||||
- nproc-per-node: specifies the number of processes to be launched per node
|
||||
- rdzv-endpoint: address of the host node
|
||||
|
||||
|
||||
### Training Configuration
|
||||
|
||||
This section gives a simple introduction on different training strategies that you can use and how to use them with our boosters and plugins to reduce training time and VRAM consumption. For more detail regarding training strategies, please refer to [here](https://colossalai.org/docs/concepts/paradigms_of_parallelism). For details regarding boosters and plugins, please refer to [here](https://colossalai.org/docs/basics/booster_plugins).
|
||||
|
||||
This section gives a simple introduction on different training strategies that you can use and how to use them with our boosters and plugins to reduce training time and VRAM consumption. For more details regarding training strategies, please refer to [here](https://colossalai.org/docs/concepts/paradigms_of_parallelism). For details regarding boosters and plugins, please refer to [here](https://colossalai.org/docs/basics/booster_plugins).
|
||||
|
||||
|
||||
<details><summary><b>Gemini</b></summary>
|
||||
|
||||
|
||||
<details><summary><b>Gemini (Zero3)</b></summary>
|
||||
|
||||
|
||||
This plugin implements Zero-3 with chunk-based and heterogeneous memory management. It can train large models without much loss in speed. It also does not support local gradient accumulation. More details can be found in [Gemini Doc](https://colossalai.org/docs/features/zero_with_chunk).
|
||||
|
||||
|
||||
Below shows how to use the gemini in SFT training.
|
||||
```
|
||||
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
|
||||
|
@ -89,13 +107,17 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
|
|||
--use_wandb
|
||||
```
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary><b>Gemini-Auto</b></summary>
|
||||
|
||||
This option use gemini and will automatically offload tensors with low priority to cpu. It also does not support local gradient accumulation. More details can be found in [Gemini Doc](https://colossalai.org/docs/features/zero_with_chunk).
|
||||
<details><summary><b>Gemini-Auto (Zero3 with Auto-Resource-Allocation-Policy)</b></summary>
|
||||
|
||||
Below shows how to use the gemin-auto in SFT training.
|
||||
|
||||
This option uses gemini and will automatically offload tensors with low priority to cpu. It also does not support local gradient accumulation. More details can be found in [Gemini Doc](https://colossalai.org/docs/features/zero_with_chunk).
|
||||
|
||||
|
||||
Below shows how to use the gemini-auto in SFT training.
|
||||
```
|
||||
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
|
@ -113,13 +135,18 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
|
|||
--use_wandb
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details><summary><b>Zero2</b></summary>
|
||||
|
||||
This option will distribute the optimizer parameters and the gradient to multiple GPUs and won't offload weights to cpu. It uses reduce and gather to synchronize gradients and weights. It does not support local gradient accumulation. Though you can accumulate gradient if you insist, it cannot reduce communication cost. That is to say, it's not a good idea to use Zero-2 with pipeline parallelism.
|
||||
|
||||
This option will distribute the optimizer parameters and the gradient to multiple GPUs and won't offload weights to cpu. It uses reduce and gather to synchronize gradients and weights. It does not support local gradient accumulation. Though you can accumulate gradients if you insist, it cannot reduce communication cost. That is to say, it's not a good idea to use Zero-2 with pipeline parallelism.
|
||||
|
||||
|
||||
Below shows how to use the zero2 in SFT training.
|
||||
```
|
||||
|
@ -139,12 +166,17 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
|
|||
--use_wandb
|
||||
```
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
|
||||
<details><summary><b>Zero2CPU</b></summary>
|
||||
|
||||
This option will distribute the optimizer parameters and the gradient to multiple GPUs as well as offload parameters to cpu. It does not support local gradient accumulation. Though you can accumulate gradient if you insist, it cannot reduce communication cost.
|
||||
|
||||
This option will distribute the optimizer parameters and the gradient to multiple GPUs as well as offload parameters to cpu. It does not support local gradient accumulation. Though you can accumulate gradients if you insist, it cannot reduce communication cost.
|
||||
|
||||
|
||||
Below shows how to use the zero2-cpu in SFT training.
|
||||
```
|
||||
|
@ -164,11 +196,20 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
|
|||
--use_wandb
|
||||
```
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details><summary><b>Tensor Parallelism</b></summary>
|
||||
|
||||
This option support Tensor Parallelism (TP). Note that if you want to use TP, zero and pipeline parallelism will be disabled. TP split large model weights/optimizer parameters/gradients into multiple small ones and distributes them to multiple GPUs, hence it is recommended to use TP when your model is large (e.g. 20B and above) or your training algorithm consumes a lot of memory (e.g. PPO).
|
||||
|
||||
This option supports Tensor Parallelism (TP). Note that if you want to use TP, TP split large model weights/optimizer parameters/gradients into multiple small ones and distributes them to multiple GPUs, hence it is recommended to use TP when your model is large (e.g. 20B and above) or your training algorithm consumes a lot of memory (e.g. PPO). Currently, we have added support for TP for the following model architectures.
|
||||
|
||||
|
||||
```
|
||||
bert, LLaMA, T5, GPT2, GPT-J, OPT, Bloom, Whisper, Sam, Blip2, ChatGLM (up to ChatGLM2), Falcon, Qwen2
|
||||
```
|
||||
|
||||
|
||||
Below shows how to use the TP in PPO training.
|
||||
```
|
||||
|
@ -181,7 +222,7 @@ colossalai run --nproc_per_node 4 --hostfile hostfile --master_port 30039 train_
|
|||
--pretrain_dataset ${ptx_dataset[@]} \
|
||||
--ptx_batch_size 1 \
|
||||
--ptx_coef 0.0 \
|
||||
--plugin "zero2" \
|
||||
--plugin "3d" \
|
||||
--save_interval 200 \
|
||||
--save_path $SAVE_DIR \
|
||||
--num_episodes 2000 \
|
||||
|
@ -200,13 +241,87 @@ colossalai run --nproc_per_node 4 --hostfile hostfile --master_port 30039 train_
|
|||
--use_wandb
|
||||
```
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details><summary><b>Sequence Parallelism</b></summary>
|
||||
|
||||
|
||||
This option supports Sequence Parallelism (SP). It is recommended to use SP when your input sequence is very long (e.g. 50K and above). Please refer to this [SP Doc](https://github.com/hpcaitech/ColossalAI/blob/b96c6390f4363f58c0df56c0ca28755f8a5f1aa2/examples/tutorial/sequence_parallel/README.md?plain=1#L1) for more information.
|
||||
|
||||
Below shows how to use the SP in SFT training.
|
||||
```
|
||||
# use the `split_gather` or `ring` sp mode
|
||||
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--dataset ${dataset[@]} \
|
||||
--save_interval 5000 \
|
||||
--save_path $SAVE_DIR \
|
||||
--config_file $CONFIG_FILE \
|
||||
--plugin 3d \
|
||||
--tp 4 \ # TP size, nproc_per_node must be divisible by it
|
||||
--sp 1 \ # SP size, must be 1
|
||||
--sp_mode 'split_gather' \ # or 'ring'
|
||||
--enable_sequence_parallelism \ # must be set
|
||||
--batch_size 4 \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps 4 \
|
||||
--lr 2e-5 \
|
||||
--max_len 2048 \
|
||||
--use_wandb
|
||||
|
||||
# use the `all_to_all` sp mode
|
||||
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--dataset ${dataset[@]} \
|
||||
--save_interval 5000 \
|
||||
--save_path $SAVE_DIR \
|
||||
--config_file $CONFIG_FILE \
|
||||
--plugin 3d \
|
||||
--tp 1 \ # TP size, must be 1
|
||||
--sp 4 \ # SP size, nproc_per_node must be divisible by it
|
||||
--sp_mode 'all_to_all' \
|
||||
--enable_sequence_parallelism \ # must be set
|
||||
--batch_size 4 \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps 4 \
|
||||
--lr 2e-5 \
|
||||
--max_len 2048 \
|
||||
--use_wandb
|
||||
```
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details><summary><b>Advanced Training Configuration with the Hybrid Plugin</b></summary>
|
||||
|
||||
User can use our HybridParallelPlugin for more advanced policy control. Currently, we have added support for the following model architectures.
|
||||
|
||||
|
||||
```
|
||||
bert, LLaMA, T5, GPT2, GPT-J, OPT, Bloom, Whisper, Sam, Blip2, ChatGLM (up to ChatGLM2), Falcon, Qwen2
|
||||
```
|
||||
|
||||
- We support mixing tensor parallelism with zero1/zero2/zero3:
|
||||
to do that, set both `tp` and `zero_stage`
|
||||
- We support mixing tensor parallelism with pipeline parallelism:
|
||||
to do that, set both `tp` and `pp`
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
|
||||
<details><summary><b>Gradient Checkpointing</b></summary>
|
||||
|
||||
|
||||
This option saves VRAM consumption by selectively recomputing some of the intermediate value on-the-fly during the backward pass, rather than storing them in memory.
|
||||
|
||||
|
||||
To enable gradient checkpointing, add --grad_checkpoint to your training script.
|
||||
```
|
||||
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
|
||||
|
@ -226,12 +341,16 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
|
|||
--use_wandb
|
||||
```
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details><summary><b>Flash Attention</b></summary>
|
||||
|
||||
|
||||
Details about flash attention can be found in the paper: [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/abs/2205.14135).
|
||||
|
||||
|
||||
To enable flash attention, add --use_flash_attn to your training script.
|
||||
```
|
||||
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
|
||||
|
@ -251,11 +370,15 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
|
|||
--use_wandb
|
||||
```
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details><summary><b>Low Rank Adaption</b></summary>
|
||||
|
||||
Details about Low Rank Adaption (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). It dramatically reduce the VRAM consumption at the cost of sacrifice model capability. It is suitable for training LLM with constrained resources.
|
||||
|
||||
Details about Low Rank Adaption (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). It dramatically reduces the VRAM consumption at the cost of sacrifice model capability. It is suitable for training LLM with constrained resources.
|
||||
|
||||
|
||||
To enable LoRA, set --lora_rank to a positive value (usually between 20 and 64).
|
||||
```
|
||||
|
@ -276,23 +399,26 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
|
|||
--use_wandb
|
||||
```
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details><summary><b>Other Training Arguments</b></summary>
|
||||
|
||||
- grad_clip: gradient larger than this value will be clipped.
|
||||
|
||||
- grad_clip: gradients larger than this value will be clipped.
|
||||
- weight_decay: weight decay hyper-parameter.
|
||||
- warmup_steps: number of warmup steps used in setting up the learning rate scheduler.
|
||||
- pretrain: pretrain model path, weights will be loaded from this pretrained model unless checkpoint_path is provided.
|
||||
- tokenizer_dir: specify where to load the tokenizer, if not provided, tokenizer will be loaded from pretrain model path.
|
||||
- dataset: a list of strings, each is a path to a folder contains buffered dataset files in arrow format.
|
||||
- tokenizer_dir: specify where to load the tokenizer, if not provided, tokenizer will be loaded from the pretrained model path.
|
||||
- dataset: a list of strings, each is a path to a folder containing buffered dataset files in arrow format.
|
||||
- checkpoint_path: if provided, will load weights from the checkpoint_path.
|
||||
- config_file: path to store the training config file.
|
||||
- save_dir: path to store the model checkpoints.
|
||||
- max_length: input will be padded/truncate to max_length before feeding to the model.
|
||||
- max_epochs: number of epoch to train.
|
||||
- max_length: input will be padded/truncated to max_length before feeding to the model.
|
||||
- max_epochs: number of epochs to train.
|
||||
- batch_size: training batch size.
|
||||
- mixed_precision: precision to use in training. Support 'fp16' and 'bf16'. Note that some device may not support the 'bf16' option, please refer to [Nvidia](https://developer.nvidia.com/) to check compatibility.
|
||||
- mixed_precision: precision to use in training. Support 'fp16' and 'bf16'. Note that some devices may not support the 'bf16' option, please refer to [Nvidia](https://developer.nvidia.com/) to check compatibility.
|
||||
- save_interval: save the model weights as well as optimizer/scheduler states every save_interval steps/episodes.
|
||||
- merge_lora_weights: whether to merge lora weights before saving the model
|
||||
- lr: the learning rate used in training.
|
||||
|
@ -300,15 +426,20 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
|
|||
- log_dir: path to store the log.
|
||||
- use_wandb: if this flag is up, you can view logs on wandb.
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
### RLHF Training Stage1 - Supervised Instructs Tuning
|
||||
|
||||
|
||||
Stage1 is supervised instructs fine-tuning (SFT). This step is a crucial part of the RLHF training process, as it involves training a machine learning model using human-provided instructions to learn the initial behavior for the task at hand. Here's a detailed guide on how to SFT your LLM with ColossalChat:
|
||||
|
||||
|
||||
#### Step 1: Data Collection
|
||||
The first step in Stage 1 is to collect a dataset of human demonstrations of the following format.
|
||||
|
||||
|
||||
```json
|
||||
[
|
||||
{"messages":
|
||||
|
@ -328,45 +459,69 @@ The first step in Stage 1 is to collect a dataset of human demonstrations of the
|
|||
]
|
||||
```
|
||||
|
||||
|
||||
#### Step 2: Preprocessing
|
||||
Once you have collected your SFT dataset, you will need to preprocess it. This involves four steps: data cleaning, data deduplication, formatting and tokenization. In this section, we will focus on formatting and tokenization.
|
||||
|
||||
|
||||
In this code we provide a flexible way for users to set the conversation template for formatting chat data using Huggingface's newest feature--- chat template. Please follow the following steps to define your chat template and preprocess your data.
|
||||
|
||||
|
||||
- Step 1: (Optional). Define your conversation template. You need to provide a conversation template config file similar to the config files under the ./config/conversation_template directory. This config should include the following fields.
|
||||
```json
|
||||
{
|
||||
"chat_template": (Optional), A string of chat_template used for formatting chat data. If not set (None), will use the default chat template of the provided tokenizer. If a path to a huggingface model or local model is provided, will use the chat_template of that model. To use a custom chat template, you need to manually set this field. For more details on how to write a chat template in Jinja format, please read https://huggingface.co/docs/transformers/main/chat_templating,
|
||||
"system_message": A string of system message to be added at the beginning of the prompt. If no is provided (None), no system message will be added,
|
||||
"stop_ids": (Optional), A list of string indicating the end of assistant's response during the rollout stage of PPO training. It's recommended to set this manually for PPO training. If not set, will set to tokenizer.eos_token_ids automatically,
|
||||
"end_of_assistant": The token(s) in string that denotes the end of assistance's response. For example, in the ChatGLM2 prompt format,
|
||||
```
|
||||
<|im_start|>system
|
||||
system messages
|
||||
|
||||
<|im_end|>
|
||||
<|im_start|>user
|
||||
How far is the moon? <|im_end|>
|
||||
<|im_start|>assistant\n The moon is about 384,400 kilometers away from Earth.<|im_end|>...
|
||||
```
|
||||
the end_of_assistant tokens are "<|im_end|>"
|
||||
"stop_ids": (Optional), A list of integers corresponds to the `end_of_assistant` tokens that indicate the end of assistance's response during the rollout stage of PPO training. It's recommended to set this manually for PPO training. If not set, will set to tokenizer.eos_token_ids automatically
|
||||
}
|
||||
```
|
||||
On your first run of the data preparation script, you only need to define the "chat_template" (if you want to use custom chat template) and the "system message" (if you want to use a custom system message),
|
||||
|
||||
|
||||
- Step 2: Run the data preparation script--- [prepare_sft_dataset.sh](./examples/data_preparation_scripts/prepare_sft_dataset.sh). Note that whether or not you have skipped the first step, you need to provide the path to the conversation template config file (via the conversation_template_config arg). If you skipped the first step, an auto-generated conversation template will be stored at the designated file path.
|
||||
|
||||
|
||||
- Step 3: (Optional) Check the correctness of the processed data. We provided an easy way for you to do a manual checking on the processed data by checking the "$SAVE_DIR/jsonl/part-XXXX.jsonl" files.
|
||||
|
||||
|
||||
Finishing the above steps, you have converted the raw conversation to the designated chat format and tokenized the formatted conversation, calculate input_ids, labels, attention_masks and buffer those into binary dataset files under "$SAVE_DIR/arrow/part-XXXX" folders.
|
||||
|
||||
|
||||
For example, our Colossal-LLaMA-2 format looks like,
|
||||
```
|
||||
<s> A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
Human: <s> what are some pranks with a pen i can do?</s> Assistant: <s> Are you looking for practical joke ideas?</s>
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
#### Step 3: Training
|
||||
Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./examples/training_scripts/train_sft.sh) to start a supervised instructs fine-tuning. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
|
||||
|
||||
|
||||
### RLHF Training Stage2 - Training Reward Model
|
||||
|
||||
|
||||
Stage2 trains a reward model, which obtains corresponding scores by manually ranking different outputs for the same prompt and supervises the training of the reward model.
|
||||
|
||||
|
||||
#### Step 1: Data Collection
|
||||
Below shows the preference dataset format used in training the reward model.
|
||||
|
||||
|
||||
```json
|
||||
[
|
||||
{"context": [
|
||||
|
@ -394,42 +549,54 @@ Below shows the preference dataset format used in training the reward model.
|
|||
]
|
||||
```
|
||||
|
||||
|
||||
#### Step 2: Preprocessing
|
||||
Similar to the second step in the previous stage, we format the reward data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./examples/data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training.
|
||||
|
||||
|
||||
#### Step 3: Training
|
||||
You can run [train_rm.sh](./examples/training_scripts/train_rm.sh) to start the reward model training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
|
||||
|
||||
|
||||
#### Features and Tricks in RM Training
|
||||
|
||||
|
||||
- We recommend using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf)and[rm-static](https://huggingface.co/datasets/Dahoas/rm-static) datasets for training the reward model.
|
||||
- We support 2 kinds of loss function named `log_sig`(used by OpenAI) and `log_exp`(used by Anthropic).
|
||||
- We log the training accuracy `train/acc`, `reward_chosen` and `reward_rejected` to monitor progress during training.
|
||||
- We use cosine-reducing lr-scheduler for RM training.
|
||||
- We set value_head as 1 liner layer and initialize the weight of value_head using N(0,1/(d_model + 1)) distribution.
|
||||
- We set value_head as one liner layer and initialize the weight of value_head using the N(0,1/(d_model + 1)) distribution.
|
||||
|
||||
|
||||
#### Note on Reward Model Training
|
||||
|
||||
Before you move on the next stage, please check the following list to ensure that your reward model is stable and robust. You can check the reward chart and the accuracy chart on wandb.
|
||||
|
||||
Before you move on to the next stage, please check the following list to ensure that your reward model is stable and robust. You can check the reward chart and the accuracy chart on wandb.
|
||||
- The mean reward for chosen data is much higher than those for rejected data
|
||||
- The accuracy is larger than 0.5 by a significant margin (usually should be greater than 0.6)
|
||||
- Optional:check the reward is positive for chosen data vice versa
|
||||
|
||||
|
||||
Your training reward curves should look similar to the following charts.
|
||||
<p align="center">
|
||||
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/mean_reward_chart.png">
|
||||
</p>
|
||||
|
||||
|
||||
### RLHF Training Stage3 - Proximal Policy Optimization
|
||||
|
||||
|
||||
In stage3 we will use reinforcement learning algorithm--- Proximal Policy Optimization (PPO), which is the most complex part of the training process:
|
||||
|
||||
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/stage-3.jpeg" width=800/>
|
||||
</p>
|
||||
|
||||
|
||||
#### Step 1: Data Collection
|
||||
PPO uses two kind of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "human" and thus the "assistant" needs to generate a response to answer to the "human". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format.
|
||||
PPO uses two kinds of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "human" and thus the "assistant" needs to generate a response to answer to the "human". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format.
|
||||
|
||||
|
||||
```json
|
||||
[
|
||||
|
@ -445,8 +612,10 @@ PPO uses two kind of training data--- the prompt data and the pretrain data (opt
|
|||
]
|
||||
```
|
||||
|
||||
|
||||
The second dataset--- pretrained dataset is optional, provide it if you want to use the ptx loss introduced in the [InstructGPT paper](https://arxiv.org/abs/2203.02155). It follows the following format.
|
||||
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
|
@ -459,11 +628,14 @@ The second dataset--- pretrained dataset is optional, provide it if you want to
|
|||
#### Step 2: Preprocessing
|
||||
To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./examples/data_preparation_scripts/prepare_prompt_dataset.sh)
|
||||
|
||||
You can use the SFT dataset you prepared in the SFT stage or prepare a new one from different source for the ptx dataset. The ptx data is used to calculate ptx loss, which stablize the training according to the [InstructGPT paper](https://arxiv.org/pdf/2203.02155.pdf).
|
||||
|
||||
You can use the SFT dataset you prepared in the SFT stage or prepare a new one from different source for the ptx dataset. The ptx data is used to calculate ptx loss, which stabilizes the training according to the [InstructGPT paper](https://arxiv.org/pdf/2203.02155.pdf).
|
||||
|
||||
|
||||
#### Step 3: Training
|
||||
You can run the [train_ppo.sh](./examples/training_scripts/train_ppo.sh) to start PPO training. Here are some unique arguments for PPO, please refer to the training configuration section for other training configuration. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
|
||||
|
||||
|
||||
```bash
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--rm_pretrain $PRETRAINED_MODEL_PATH \ # reward model architectural
|
||||
|
@ -482,7 +654,9 @@ You can run the [train_ppo.sh](./examples/training_scripts/train_ppo.sh) to star
|
|||
--accumulation_steps 2
|
||||
```
|
||||
|
||||
Each episode has two phases, the collect phase and the update phase. During the collect phase, we will collect experiences (answers generated by actor), store those in ExperienceBuffer. Then data in ExperienceBuffer is used during the update phase to update parameter of actor and critic.
|
||||
|
||||
Each episode has two phases, the collect phase and the update phase. During the collect phase, we will collect experiences (answers generated by the actor), store those in ExperienceBuffer. Then data in ExperienceBuffer is used during the update phase to update parameters of actor and critic.
|
||||
|
||||
|
||||
- Without tensor parallelism,
|
||||
```
|
||||
|
@ -491,6 +665,7 @@ experience buffer size
|
|||
= train_batch_size * accumulation_steps * num_process
|
||||
```
|
||||
|
||||
|
||||
- With tensor parallelism,
|
||||
```
|
||||
num_tp_group = num_process / tp
|
||||
|
@ -499,47 +674,60 @@ experience buffer size
|
|||
= train_batch_size * accumulation_steps * num_tp_group
|
||||
```
|
||||
|
||||
|
||||
### Sample Training Results Using Default Script
|
||||
#### Reward
|
||||
<p align="center">
|
||||
<img width="700" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/reward.png">
|
||||
</p>
|
||||
|
||||
|
||||
### Note on PPO Training
|
||||
#### Q1: My reward is negative
|
||||
Answer: Check your reward model trained in stage 1. If the reward model only generate negative reward, we actually will expect a negative reward. However, even though the reward is negative, the reward should go up.
|
||||
Answer: Check your reward model trained in stage 1. If the reward model only generates negative reward, we actually will expect a negative reward. However, even though the reward is negative, the reward should go up.
|
||||
|
||||
|
||||
#### Q2: My actor loss is negative
|
||||
Answer: This is normal for actor loss as PPO doesn't restrict the actor loss to be positive.
|
||||
|
||||
|
||||
#### Q3: My reward doesn't go up (decreases)
|
||||
Answer: The causes to this problem are two-fold. Check your reward model, make sure that it gives positive and strong reward for good cases and negative, strong reward for bad responses. You should also try different hyperparameter settings.
|
||||
Answer: The causes of this problem are two-fold. Check your reward model, make sure that it gives positive and strong reward for good cases and negative, strong reward for bad responses. You should also try different hyperparameter settings.
|
||||
|
||||
|
||||
#### Q4: Generation is garbage
|
||||
Answer: Yes, this happens and is well documented by other implementations. After training for too many episodes, the actor gradually deviate from its original state, which may leads to decrease in language modeling capabilities. A way to fix this is to add supervised loss during PPO. Set ptx_coef to a none-zero value (between 0 and 1), which balances PPO loss and sft loss.
|
||||
Answer: Yes, this happens and is well documented by other implementations. After training for too many episodes, the actor gradually deviate from its original state, which may leads to decrease in language modeling capabilities. A way to fix this is to add supervised loss during PPO. Set ptx_coef to an non-zero value (between 0 and 1), which balances PPO loss and sft loss.
|
||||
|
||||
|
||||
## Alternative Option For RLHF: Direct Preference Optimization
|
||||
|
||||
|
||||
For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in the paper (available at [https://arxiv.org/abs/2305.18290](https://arxiv.org/abs/2305.18290)), DPO offers an low-cost way to perform RLHF and usually request less computation resources compares to PPO.
|
||||
|
||||
|
||||
### DPO Training Stage1 - Supervised Instructs Tuning
|
||||
|
||||
|
||||
Please refer the [sft section](#dpo-training-stage1---supervised-instructs-tuning) in the PPO part.
|
||||
|
||||
|
||||
### DPO Training Stage2 - DPO Training
|
||||
#### Step 1: Data Collection & Preparation
|
||||
For DPO training, you only need the preference dataset. Please follow the instruction in the [preference dataset preparation section](#rlhf-training-stage2---training-reward-model) to prepare the preference data for DPO training.
|
||||
|
||||
|
||||
#### Step 2: Training
|
||||
You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
|
||||
|
||||
|
||||
#### DPO Result
|
||||
<p align="center">
|
||||
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/DPO.png">
|
||||
</p>
|
||||
|
||||
|
||||
## Hardware Requirements
|
||||
For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM consumption of training a 7B model on a dummy dataset with 2048 sequence length and 512 layout length with different tp_size (equal to the number of GPUs). In this experiment, we use H800 GPU with 80GB VRAM.
|
||||
For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM consumption of training a 7B model on a dummy dataset with 2048 sequence length and 512 layout length with different tp_size (equal to the number of GPUs). In this experiment, we use an H800 GPU with 80GB VRAM.
|
||||
| PPO | tp=8 | tp=4 |
|
||||
|-------|---------------|---------------|
|
||||
| bs=1 | 18485.19 MB | 42934.45 MB |
|
||||
|
@ -547,19 +735,45 @@ For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM
|
|||
| bs=16 | 41408.28 MB | 56778.97 MB |
|
||||
| bs=30 | 64047.42 MB | failed |
|
||||
|
||||
|
||||
For DPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption on a dummy dataset with 2048 sequence length.
|
||||
|
||||
|
||||
- 1 H800 GPU
|
||||
- zero2-cpu, batch size=2, VRAM Usage=49873.90 MB
|
||||
- zero2-cpu, batch size=4, VRAM Usage=60998.22 MB
|
||||
- 4 H800 GPUs
|
||||
- zero2, batch size=4, VRAM Usage=67544.47 MB
|
||||
|
||||
## List of Supported Models
|
||||
|
||||
For SFT, we support the following models/series:
|
||||
- Colossal-LLaMA-2
|
||||
- ChatGLM2
|
||||
- ChatGLM3 (only with zero2, zero2_cpu plugin)
|
||||
- Baichuan2
|
||||
- LLaMA2
|
||||
- Qwen1.5-7B-Chat (with transformers==4.39.1)
|
||||
- Yi-1.5
|
||||
|
||||
For PPO and DPO, we theoratically support the following models/series (without guarantee):
|
||||
- Colossal-LLaMA-2 (tested)
|
||||
- ChatGLM2
|
||||
- Baichuan2
|
||||
- LLaMA2 (tested)
|
||||
- Qwen1.5-7B-Chat (with transformers==4.39.1)
|
||||
- Yi-1.5
|
||||
|
||||
*-* The zero2, zero2_cpu plugin also support a wide range of chat models not listed above.
|
||||
|
||||
## Inference example
|
||||
|
||||
|
||||
We support different inference options, including int8 and int4 quantization.
|
||||
For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
|
||||
|
||||
|
||||
## Attention
|
||||
|
||||
|
||||
The examples are demos for the whole training process. You need to change the hyper-parameters to reach great performance.
|
||||
|
|
|
@ -5,7 +5,7 @@ rm -rf $SAVE_DIR/jsonl
|
|||
rm -rf $SAVE_DIR/arrow
|
||||
|
||||
python prepare_dataset.py --type sft \
|
||||
--data_input_dirs /PATH/TO/SFT/DATASET \
|
||||
--data_input_dirs "PATH/TO/SFT/DATA" \
|
||||
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
|
||||
--tokenizer_dir "" \
|
||||
--data_cache_dir $SAVE_DIR/cache \
|
||||
|
|
|
@ -1 +1,5 @@
|
|||
10.20.1.82
|
||||
XXX.XX.XXX.XXX # Your master IP
|
||||
XXX.XX.XXX.XXX # Your slave IPs
|
||||
XXX.XX.XXX.XXX # Your slave IPs
|
||||
XXX.XX.XXX.XXX # Your slave IPs
|
||||
XXX.XX.XXX.XXX # Your slave IPs
|
||||
|
|
|
@ -5,12 +5,7 @@ import resource
|
|||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from coati.dataset import (
|
||||
DataCollatorForPreferenceDataset,
|
||||
StatefulDistributedSampler,
|
||||
load_tokenized_dataset,
|
||||
setup_distributed_dataloader,
|
||||
)
|
||||
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
|
||||
from coati.models import convert_to_lora_module, disable_dropout
|
||||
from coati.trainer import DPOTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
|
@ -56,6 +51,7 @@ def train(args):
|
|||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=True,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
|
@ -63,6 +59,7 @@ def train(args):
|
|||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
|
@ -82,9 +79,15 @@ def train(args):
|
|||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=1,
|
||||
zero_stage=0,
|
||||
pp_size=args.pp,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
zero_stage=args.zero_stage,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
enable_sequence_parallelism=args.enable_sequence_parallelism,
|
||||
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
|
||||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
)
|
||||
else:
|
||||
|
@ -166,13 +169,14 @@ def train(args):
|
|||
mode_map = {"train": "train", "valid": "validation", "test": "test"}
|
||||
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
|
||||
data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
train_dataloader = setup_distributed_dataloader(
|
||||
|
||||
train_dataloader = plugin.prepare_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
use_tp=args.tp > 1,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
|
||||
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
|
||||
|
@ -290,6 +294,12 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--model_type", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
|
|
|
@ -12,7 +12,6 @@ from coati.dataset import (
|
|||
StatefulDistributedSampler,
|
||||
load_tokenized_dataset,
|
||||
setup_conversation_template,
|
||||
setup_distributed_dataloader,
|
||||
)
|
||||
from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout
|
||||
from coati.trainer import PPOTrainer
|
||||
|
@ -26,6 +25,7 @@ from colossalai.cluster import DistCoordinator
|
|||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.shardformer.policies.auto_policy import get_autopolicy
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
@ -51,7 +51,6 @@ def train(args):
|
|||
# )
|
||||
|
||||
init_ctx = nullcontext()
|
||||
booster_policy = None
|
||||
with init_ctx:
|
||||
if args.use_flash_attn:
|
||||
actor = AutoModelForCausalLM.from_pretrained(
|
||||
|
@ -86,32 +85,6 @@ def train(args):
|
|||
disable_dropout(actor)
|
||||
disable_dropout(critic)
|
||||
|
||||
if args.tp > 1:
|
||||
if reward_model.model.config.architectures[0] != critic.model.config.architectures[0]:
|
||||
raise ValueError("Reward model and critic model must have the same architecture")
|
||||
if reward_model.model.config.architectures[0] == "BloomForCausalLM":
|
||||
from colossalai.shardformer.policies.bloom import BloomPolicy
|
||||
|
||||
booster_policy = BloomPolicy()
|
||||
elif reward_model.model.config.architectures[0] == "LlamaForCausalLM":
|
||||
from colossalai.shardformer.policies.llama import LlamaPolicy
|
||||
|
||||
booster_policy = LlamaPolicy()
|
||||
elif reward_model.model.config.architectures[0] == "GPT2LMHeadModel":
|
||||
from colossalai.shardformer.policies.gpt2 import GPT2Policy
|
||||
|
||||
booster_policy = GPT2Policy()
|
||||
elif reward_model.model.config.architectures[0] == "ChatGLMModel":
|
||||
from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
|
||||
|
||||
booster_policy = ChatGLMPolicy()
|
||||
elif reward_model.model.config.architectures[0] == "OPTForCausalLM":
|
||||
from colossalai.shardformer.policies.opt import OPTPolicy
|
||||
|
||||
booster_policy = OPTPolicy()
|
||||
else:
|
||||
raise ValueError("Unknown model architecture for policy")
|
||||
|
||||
if args.lora_rank > 0:
|
||||
actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
@ -175,34 +148,6 @@ def train(args):
|
|||
adamw_mode=True,
|
||||
)
|
||||
|
||||
# configure dataset
|
||||
coordinator.print_on_master(f"Load dataset: {args.prompt_dataset}")
|
||||
mode_map = {"train": "train", "valid": "validation", "test": "test"}
|
||||
train_prompt_dataset = load_tokenized_dataset(dataset_paths=args.prompt_dataset, mode="train", mode_map=mode_map)
|
||||
data_collator = DataCollatorForPromptDataset(tokenizer=tokenizer, max_length=args.max_length - args.max_seq_len)
|
||||
train_prompt_dataloader = setup_distributed_dataloader(
|
||||
dataset=train_prompt_dataset,
|
||||
batch_size=args.experience_batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
use_tp=args.tp > 1,
|
||||
)
|
||||
|
||||
if len(args.ptx_dataset) > 0:
|
||||
train_ptx_dataset = load_tokenized_dataset(dataset_paths=args.ptx_dataset, mode="train", mode_map=mode_map)
|
||||
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
train_pretrain_dataloader = setup_distributed_dataloader(
|
||||
dataset=train_ptx_dataset,
|
||||
batch_size=args.ptx_batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
use_tp=args.tp > 1,
|
||||
)
|
||||
else:
|
||||
train_pretrain_dataloader = None
|
||||
|
||||
if args.warmup_steps is None:
|
||||
args.warmup_steps = int(0.025 * args.num_episodes)
|
||||
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||
|
@ -237,6 +182,7 @@ def train(args):
|
|||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=True,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
|
@ -244,6 +190,7 @@ def train(args):
|
|||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
|
@ -261,20 +208,35 @@ def train(args):
|
|||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "3d":
|
||||
if args.use_flash_attn and (args.tp > 1 or args.pp > 1 or args.sp > 1 or args.enable_sequence_parallelism):
|
||||
logger.warning("Flash attention cannot be used with 3D parallelism for PPO training. Disabling it.")
|
||||
args.use_flash_attn = False
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=1,
|
||||
zero_stage=0,
|
||||
pp_size=args.pp,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
zero_stage=args.zero_stage,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
enable_sequence_parallelism=args.enable_sequence_parallelism,
|
||||
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
|
||||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
)
|
||||
custom_plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=1,
|
||||
zero_stage=0,
|
||||
pp_size=args.pp,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
zero_stage=args.zero_stage,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
enable_sequence_parallelism=args.enable_sequence_parallelism,
|
||||
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
|
||||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
custom_policy=booster_policy,
|
||||
custom_policy=get_autopolicy(reward_model.model),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
@ -282,6 +244,35 @@ def train(args):
|
|||
if args.plugin != "3d":
|
||||
custom_plugin = plugin
|
||||
|
||||
# configure dataset
|
||||
coordinator.print_on_master(f"Load dataset: {args.prompt_dataset}")
|
||||
mode_map = {"train": "train", "valid": "validation", "test": "test"}
|
||||
train_prompt_dataset = load_tokenized_dataset(dataset_paths=args.prompt_dataset, mode="train", mode_map=mode_map)
|
||||
data_collator = DataCollatorForPromptDataset(tokenizer=tokenizer, max_length=args.max_length - args.max_seq_len)
|
||||
|
||||
train_prompt_dataloader = plugin.prepare_dataloader(
|
||||
dataset=train_prompt_dataset,
|
||||
batch_size=args.experience_batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
|
||||
if len(args.ptx_dataset) > 0:
|
||||
train_ptx_dataset = load_tokenized_dataset(dataset_paths=args.ptx_dataset, mode="train", mode_map=mode_map)
|
||||
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
train_pretrain_dataloader = plugin.prepare_dataloader(
|
||||
dataset=train_ptx_dataset,
|
||||
batch_size=args.ptx_batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
else:
|
||||
train_pretrain_dataloader = None
|
||||
|
||||
actor_booster = Booster(plugin=plugin)
|
||||
ref_booster = Booster(plugin=plugin)
|
||||
rm_booster = Booster(plugin=custom_plugin)
|
||||
|
@ -474,6 +465,12 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--rm_pretrain", type=str, default=None)
|
||||
parser.add_argument("--checkpoint_path", type=str, default=None)
|
||||
|
|
|
@ -6,12 +6,7 @@ import resource
|
|||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from coati.dataset import (
|
||||
DataCollatorForPreferenceDataset,
|
||||
StatefulDistributedSampler,
|
||||
load_tokenized_dataset,
|
||||
setup_distributed_dataloader,
|
||||
)
|
||||
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
|
||||
from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module
|
||||
from coati.trainer import RewardModelTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
|
@ -23,6 +18,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLev
|
|||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.shardformer.policies.auto_policy import get_autopolicy
|
||||
|
||||
|
||||
def train(args):
|
||||
|
@ -46,7 +42,6 @@ def train(args):
|
|||
# )
|
||||
|
||||
init_ctx = nullcontext()
|
||||
booster_policy = None
|
||||
with init_ctx:
|
||||
if args.use_flash_attn:
|
||||
model = RewardModel(
|
||||
|
@ -56,31 +51,9 @@ def train(args):
|
|||
)
|
||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||
else:
|
||||
model = RewardModel(args.pretrain)
|
||||
|
||||
if args.tp > 1:
|
||||
if model.model.config.architectures[0] == "BloomForCausalLM":
|
||||
from colossalai.shardformer.policies.bloom import BloomPolicy
|
||||
|
||||
booster_policy = BloomPolicy()
|
||||
elif model.model.config.architectures[0] == "LlamaForCausalLM":
|
||||
from colossalai.shardformer.policies.llama import LlamaPolicy
|
||||
|
||||
booster_policy = LlamaPolicy()
|
||||
elif model.model.config.architectures[0] == "GPT2LMHeadModel":
|
||||
from colossalai.shardformer.policies.gpt2 import GPT2Policy
|
||||
|
||||
booster_policy = GPT2Policy()
|
||||
elif model.model.config.architectures[0] == "ChatGLMModel":
|
||||
from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
|
||||
|
||||
booster_policy = ChatGLMPolicy()
|
||||
elif model.model.config.architectures[0] == "OPTForCausalLM":
|
||||
from colossalai.shardformer.policies.opt import OPTPolicy
|
||||
|
||||
booster_policy = OPTPolicy()
|
||||
else:
|
||||
raise ValueError("Unknown model architecture for policy")
|
||||
model = RewardModel(
|
||||
args.pretrain,
|
||||
)
|
||||
|
||||
if args.lora_rank > 0:
|
||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
@ -100,6 +73,7 @@ def train(args):
|
|||
placement_policy="static",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
enable_gradient_accumulation=True,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
|
@ -107,6 +81,7 @@ def train(args):
|
|||
precision=args.mixed_precision,
|
||||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
|
@ -127,11 +102,17 @@ def train(args):
|
|||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=1,
|
||||
zero_stage=0,
|
||||
pp_size=args.pp,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
zero_stage=args.zero_stage,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
enable_sequence_parallelism=args.enable_sequence_parallelism,
|
||||
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
|
||||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
custom_policy=booster_policy,
|
||||
custom_policy=get_autopolicy(model.model),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
@ -183,15 +164,15 @@ def train(args):
|
|||
mode_map = {"train": "train", "valid": "validation", "test": "test"}
|
||||
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
|
||||
data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
train_dataloader = setup_distributed_dataloader(
|
||||
|
||||
train_dataloader = plugin.prepare_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
use_tp=args.tp > 1,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
|
||||
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
|
||||
math.ceil(args.max_epochs * num_update_steps_per_epoch)
|
||||
|
||||
|
@ -307,6 +288,12 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
|
|
|
@ -6,7 +6,7 @@ import resource
|
|||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from coati.dataset import DataCollatorForSupervisedDataset, load_tokenized_dataset, setup_distributed_dataloader
|
||||
from coati.dataset import DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset
|
||||
from coati.models import convert_to_lora_module
|
||||
from coati.trainer import SFTTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
|
@ -16,9 +16,12 @@ import colossalai
|
|||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def train(args):
|
||||
# check lora compatibility
|
||||
|
@ -35,6 +38,24 @@ def train(args):
|
|||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
init_ctx = nullcontext()
|
||||
with init_ctx:
|
||||
if args.use_flash_attn:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
attn_implementation="flash_attention_2",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if args.lora_rank > 0:
|
||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
||||
if args.plugin == "ddp":
|
||||
"""
|
||||
Default torch ddp plugin without any acceleration, for
|
||||
|
@ -47,7 +68,8 @@ def train(args):
|
|||
placement_policy="static",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=True,
|
||||
enable_gradient_accumulation=True if args.accumulation_steps > 1 else False,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
|
@ -55,6 +77,7 @@ def train(args):
|
|||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
|
@ -74,11 +97,17 @@ def train(args):
|
|||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=1,
|
||||
zero_stage=0,
|
||||
pp_size=args.pp,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
zero_stage=args.zero_stage,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
enable_sequence_parallelism=args.enable_sequence_parallelism,
|
||||
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
|
||||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
microbatch_size=args.batch_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
@ -93,20 +122,6 @@ def train(args):
|
|||
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
# )
|
||||
|
||||
init_ctx = nullcontext()
|
||||
with init_ctx:
|
||||
if args.use_flash_attn:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||
if args.lora_rank > 0:
|
||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
||||
if args.grad_checkpoint and args.lora_rank == 0:
|
||||
# lora layers are not supported by gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
@ -131,6 +146,7 @@ def train(args):
|
|||
|
||||
tokenizer.add_bos_token = False
|
||||
tokenizer.add_eos_token = False
|
||||
tokenizer.padding_side = "right"
|
||||
|
||||
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
|
||||
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_path}")
|
||||
|
@ -150,13 +166,14 @@ def train(args):
|
|||
)
|
||||
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
|
||||
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_len)
|
||||
train_dataloader = setup_distributed_dataloader(
|
||||
|
||||
train_dataloader = plugin.prepare_dataloader(
|
||||
dataset=dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
use_tp=args.tp > 1,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
|
@ -185,7 +202,6 @@ def train(args):
|
|||
lr_scheduler=lr_scheduler,
|
||||
dataloader=train_dataloader,
|
||||
)
|
||||
# model = model.to(get_current_device())
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
|
@ -255,7 +271,7 @@ def train(args):
|
|||
# save model checkpoint after fitting on only rank0
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
|
||||
booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
|
||||
# booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
|
||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}")
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
@ -270,13 +286,19 @@ if __name__ == "__main__":
|
|||
"--plugin",
|
||||
type=str,
|
||||
default="gemini",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp"],
|
||||
choices=["gemini", "gemini_auto", "3d", "ddp", "zero2_cpu", "zero2"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
|
@ -287,7 +309,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument("--max_len", type=int, default=512)
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument(
|
||||
"--lora_train_bias",
|
||||
|
|
|
@ -15,7 +15,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
|||
|
||||
|
||||
# export CUDA_VISIBLE_DEVICES=4,5,6
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 4
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
||||
PROJECT_NAME="sft"
|
||||
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
|
||||
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
|
||||
|
@ -40,8 +40,10 @@ FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
|
|||
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
|
||||
|
||||
echo $(which colossalai)
|
||||
echo $(which python)
|
||||
# the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size
|
||||
colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile train_sft.py \
|
||||
colossalai run --nproc_per_node 2 --master_port 31312 --hostfile ./hostfile train_sft.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--save_interval 4000 \
|
||||
|
@ -49,11 +51,15 @@ colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile trai
|
|||
--save_path $SAVE_DIR \
|
||||
--config_file $CONFIG_FILE \
|
||||
--lora_rank 0 \
|
||||
--plugin zero2 \
|
||||
--batch_size 8 \
|
||||
--max_epochs 1 \
|
||||
--plugin 3d \
|
||||
--tp 2 \
|
||||
--pp 1 \
|
||||
--zero_stage 0 \
|
||||
--batch_size 2 \
|
||||
--max_epochs 3 \
|
||||
--accumulation_steps 1 \
|
||||
--lr 2e-5 \
|
||||
--max_len 2048 \
|
||||
--lr 5e-5 \
|
||||
--max_len 400 \
|
||||
--grad_checkpoint \
|
||||
--use_wandb
|
||||
--use_wandb \
|
||||
--use_flash_attn
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
transformers==4.34.1
|
||||
huggingface_hub==0.17.3
|
||||
transformers>=4.36.2
|
||||
tqdm
|
||||
datasets
|
||||
datasets==2.14.7
|
||||
loralib
|
||||
colossalai>=0.3.6
|
||||
colossalai>=0.3.7
|
||||
torch>=1.12.1
|
||||
langchain
|
||||
tokenizers
|
||||
|
|
|
@ -4,5 +4,6 @@
|
|||
"stop_ids": [
|
||||
29871,
|
||||
2
|
||||
]
|
||||
],
|
||||
"end_of_assistant": "</s>"
|
||||
}
|
||||
|
|
|
@ -6,7 +6,8 @@ TEST_DATA_DIR=$BASE_DIR/tests/test_data
|
|||
DATA_SAVE_PATH=$BASE_TEMP_DIR/tests
|
||||
CONFIG_DIR=$BASE_DIR/config
|
||||
|
||||
MODELS=("colossal-llama2" "llama2" "zephyr" "mistral" "chatGLM2" "Qwen" "Vicuna" "Yi")
|
||||
# MODELS=("colossal-llama2" "llama2" "mistral" "chatGLM2" "chatGLM3" "deepseek" "Yi" "baichuan") # for local test
|
||||
MODELS=("colossal-llama2" "llama2" "chatGLM2" "chatGLM3" "deepseek" "Yi")
|
||||
|
||||
get_pretrain() {
|
||||
local model=$1
|
||||
|
@ -14,27 +15,51 @@ get_pretrain() {
|
|||
echo "hpcai-tech/Colossal-LLaMA-2-7b-base"
|
||||
elif [[ $model == "llama2" ]]; then
|
||||
echo "hf-internal-testing/llama-tokenizer"
|
||||
elif [[ $model == "zephyr" ]]; then
|
||||
echo "HuggingFaceH4/zephyr-7b-beta"
|
||||
elif [[ $model == "phi" ]]; then
|
||||
echo "microsoft/phi-2"
|
||||
elif [[ $model == "mistral" ]]; then
|
||||
echo "mistralai/Mistral-7B-Instruct-v0.2"
|
||||
echo "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
elif [[ $model == "chatGLM2" ]]; then
|
||||
echo "THUDM/chatglm2-6b"
|
||||
elif [[ $model == "Qwen" ]]; then
|
||||
echo "Qwen/Qwen-7B-Chat"
|
||||
elif [[ $model == "Vicuna" ]]; then
|
||||
echo "lmsys/vicuna-7b-v1.5"
|
||||
elif [[ $model == "chatGLM3" ]]; then
|
||||
echo "THUDM/chatglm3-6b"
|
||||
elif [[ $model == "deepseek" ]]; then
|
||||
echo "deepseek-ai/DeepSeek-V2-Lite"
|
||||
elif [[ $model == "Yi" ]]; then
|
||||
echo "01-ai/Yi-6B-Chat"
|
||||
echo "01-ai/Yi-1.5-9B-Chat"
|
||||
elif [[ $model == "baichuan" ]]; then
|
||||
echo "baichuan-inc/Baichuan2-13B-Chat"
|
||||
else
|
||||
echo "Unknown model $model"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
|
||||
get_conversation_template_config() {
|
||||
local model=$1
|
||||
echo "$CONFIG_DIR/conversation_template/$model.json"
|
||||
if [[ $model == "colossal-llama2" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/colossal-llama2.json"
|
||||
elif [[ $model == "llama2" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/llama2.json"
|
||||
elif [[ $model == "deepseek" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/deepseek-ai_DeepSeek-V2-Lite.json"
|
||||
elif [[ $model == "mistral" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
|
||||
elif [[ $model == "chatGLM2" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/THUDM_chatglm2-6b.json"
|
||||
elif [[ $model == "chatGLM3" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/THUDM_chatglm3-6b.json"
|
||||
elif [[ $model == "phi" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/microsoft_phi-2.json"
|
||||
elif [[ $model == "Yi" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/01-ai_Yi-1.5-9B-Chat.json"
|
||||
elif [[ $model == "baichuan" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/baichuan-inc_Baichuan2-13B-Chat.json"
|
||||
else
|
||||
echo "Unknown model $model"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Test SFT data Preparation
|
||||
|
|
|
@ -30,7 +30,8 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
|
|||
MODELS_DIR=$TEMP_DIR/models_config
|
||||
# Skip those tests due to CI tests timeout
|
||||
MODELS=('llama')
|
||||
PLUGINS=('gemini' 'gemini_auto' 'zero2' 'zero2_cpu' '3d')
|
||||
ADVANCED_PLUGINS=('sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu') # pp is still buggy
|
||||
PLUGINS=('3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu')
|
||||
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
|
||||
|
||||
export OMP_NUM_THREADS=8
|
||||
|
@ -80,6 +81,8 @@ random_choice() {
|
|||
}
|
||||
|
||||
|
||||
|
||||
|
||||
echo "[Test]: testing sft ..."
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
|
@ -91,7 +94,7 @@ SKIPPED_TESTS=(
|
|||
GRAD_CKPTS=('--grad_checkpoint')
|
||||
for lora_rank in ${LORA_RANK[@]}; do
|
||||
for model in ${MODELS[@]}; do
|
||||
for plugin in ${PLUGINS[@]}; do
|
||||
for plugin in ${ADVANCED_PLUGINS[@]}; do
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin-$lora_rank"
|
||||
continue
|
||||
|
@ -104,10 +107,56 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
|
||||
tp='1'
|
||||
bs='2'
|
||||
pp='1'
|
||||
zero_stage='0'
|
||||
sp='1'
|
||||
sp_mode='split_gather'
|
||||
enable_sequence_parallelism=''
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='4'
|
||||
bs='8'
|
||||
fi
|
||||
if [[ $plugin == "tp_zero2" ]]; then
|
||||
tp='4'
|
||||
bs='8'
|
||||
zero_stage='2'
|
||||
plugin='3d'
|
||||
fi
|
||||
if [[ $plugin == "tp_pp" ]]; then
|
||||
tp='2'
|
||||
bs='8'
|
||||
pp='2'
|
||||
plugin='3d'
|
||||
fi
|
||||
if [[ $plugin == "pp" ]]; then
|
||||
bs='8'
|
||||
pp='4'
|
||||
plugin='3d'
|
||||
fi
|
||||
if [[ $plugin == "sp_split_gather" ]]; then
|
||||
enable_sequence_parallelism='--enable_sequence_parallelism'
|
||||
sp_mode='split_gather'
|
||||
tp='4'
|
||||
sp='1'
|
||||
bs='8'
|
||||
plugin='3d'
|
||||
fi
|
||||
if [[ $plugin == "sp_ring" ]]; then
|
||||
enable_sequence_parallelism='--enable_sequence_parallelism'
|
||||
sp_mode='ring'
|
||||
tp='4'
|
||||
sp='1'
|
||||
bs='8'
|
||||
plugin='3d'
|
||||
fi
|
||||
if [[ $plugin == "sp_all_to_all" ]]; then
|
||||
enable_sequence_parallelism='--enable_sequence_parallelism'
|
||||
sp_mode='all_to_all'
|
||||
tp='1'
|
||||
sp='4'
|
||||
bs='8'
|
||||
plugin='3d'
|
||||
fi
|
||||
grad_accu='2'
|
||||
# Check if the plugin is either "gemini_auto" or "gemini" and set grad_accu to '1'
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
|
@ -132,6 +181,11 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||
--max_epochs 1 \
|
||||
--accumulation_steps $grad_accu \
|
||||
--tp $tp \
|
||||
--pp $pp \
|
||||
--zero_stage $zero_stage \
|
||||
--sp $sp \
|
||||
--sp_mode $sp_mode \
|
||||
$enable_sequence_parallelism \
|
||||
--lr 2e-5 \
|
||||
$grad_ckpt \
|
||||
--max_len 400 \
|
||||
|
@ -226,8 +280,8 @@ echo "[Test]: testing ppo ..."
|
|||
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
llama-3d-20 # 3d plugin doesn't support lora
|
||||
llama-gemini-20 # gemini doesn't support lora
|
||||
llama-3d # 3d plugin doesn't support lora
|
||||
llama-gemini # gemini doesn't support lora
|
||||
)
|
||||
|
||||
GRAD_CKPTS=('--grad_checkpoint')
|
||||
|
@ -304,7 +358,7 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||
$grad_ckpt \
|
||||
--max_len 400 \
|
||||
--max_seq_len 10 \
|
||||
--use_flash_attn
|
||||
# --use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf $MODEL_SAVE_PATH/*
|
||||
|
|
Loading…
Reference in New Issue