mirror of https://github.com/hpcaitech/ColossalAI
[chat] fix bugs and add unit tests (#4213)
* style: rename replay buffer Experience replay is typically for off policy algorithms. Use this name in PPO maybe misleading. * fix: fix wrong zero2 default arg * test: update experience tests * style: rename zero_pad fn * fix: defer init in CycledDataLoader * test: add benchmark test * style: rename internal fn of generation * style: rename internal fn of lora * fix: remove unused loss fn * fix: remove unused utils fn * refactor: remove generate_with_actor fn * fix: fix type annotation * test: add models tests * fix: skip llama due to long execution time * style: modify dataset * style: apply formatter * perf: update reward dataset * fix: fix wrong IGNORE_INDEX in sft dataset * fix: remove DataCollatorForSupervisedDataset * test: add dataset tests * style: apply formatter * style: rename test_ci to test_train * feat: add llama in inference * test: add inference tests * test: change test scripts directory * fix: update ci * fix: fix typo * fix: skip llama due to oom * fix: fix file mod * style: apply formatter * refactor: remove duplicated llama_gptq * style: apply formatter * to: update rm test * feat: add tokenizer arg * feat: add download model script * test: update train tests * fix: modify gemini load and save pretrained * test: update checkpoint io test * to: modify nproc_per_node * fix: do not remove existing dir * fix: modify save path * test: add random choice * fix: fix sft path * fix: enlarge nproc_per_node to avoid oom * fix: add num_retry * fix: make lora config of rm and critic consistent * fix: add warning about lora weights * fix: skip some gpt2 tests * fix: remove grad ckpt in rm and critic due to errors * refactor: directly use Actor in train_sft * test: add more arguments * fix: disable grad ckpt when using lora * fix: fix save_pretrained and related tests * test: enable zero2 tests * revert: remove useless fn * style: polish code * test: modify test argspull/4386/head
parent
16bf4c0221
commit
da4f7b855f
|
@ -43,7 +43,9 @@ jobs:
|
|||
run: |
|
||||
cd applications/Chat
|
||||
rm -rf ~/.cache/colossalai
|
||||
./examples/test_ci.sh
|
||||
./tests/test_inference.sh
|
||||
./tests/test_benchmarks.sh
|
||||
./tests/test_train.sh
|
||||
env:
|
||||
NCCL_SHM_DISABLE: 1
|
||||
MAX_JOBS: 8
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
from .prompt_dataset import PromptDataset
|
||||
from .reward_dataset import HhRlhfDataset, RmStaticDataset
|
||||
from .sft_dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
|
||||
from .sft_dataset import SFTDataset, SupervisedDataset
|
||||
from .utils import is_rank_0
|
||||
|
||||
__all__ = [
|
||||
'RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset', 'SupervisedDataset',
|
||||
'DataCollatorForSupervisedDataset', 'PromptDataset'
|
||||
'RmStaticDataset', 'HhRlhfDataset',
|
||||
'SFTDataset', 'SupervisedDataset',
|
||||
'PromptDataset', 'is_rank_0',
|
||||
]
|
||||
|
|
|
@ -1,20 +1,13 @@
|
|||
import copy
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, Sequence
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .utils import is_rank_0, jload
|
||||
|
||||
logger = get_dist_logger()
|
||||
from .utils import jload
|
||||
|
||||
|
||||
class PromptDataset(Dataset):
|
||||
|
@ -27,12 +20,13 @@ class PromptDataset(Dataset):
|
|||
max_length: int = 96):
|
||||
super(PromptDataset, self).__init__()
|
||||
self.keyed_prompt = defaultdict(list)
|
||||
logger.info("Loading data...")
|
||||
self.logger = get_dist_logger()
|
||||
self.logger.info("Loading data...")
|
||||
list_data_dict = jload(data_path)
|
||||
logger.info(f"Loaded {len(list_data_dict)} examples.")
|
||||
self.logger.info(f"Loaded {len(list_data_dict)} examples.")
|
||||
|
||||
if max_datasets_size is not None:
|
||||
logger.info(f"Limiting dataset to {max_datasets_size} examples.")
|
||||
self.logger.info(f"Limiting dataset to {max_datasets_size} examples.")
|
||||
list_data_dict = list_data_dict[:max_datasets_size]
|
||||
|
||||
instructions = [data_dict["instruction"] for data_dict in list_data_dict]
|
||||
|
|
|
@ -20,44 +20,44 @@ class RmStaticDataset(Dataset):
|
|||
|
||||
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
|
||||
super().__init__()
|
||||
self.chosen = []
|
||||
self.reject = []
|
||||
if special_token is None:
|
||||
self.end_token = tokenizer.eos_token
|
||||
else:
|
||||
self.end_token = special_token
|
||||
for data in tqdm(dataset, disable=not is_rank_0()):
|
||||
prompt = data['prompt']
|
||||
self.end_token = tokenizer.eos_token \
|
||||
if special_token is None else special_token
|
||||
|
||||
chosen = prompt + data['chosen'] + self.end_token
|
||||
chosen_token = tokenizer(chosen,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.chosen.append({
|
||||
"input_ids": chosen_token['input_ids'],
|
||||
"attention_mask": chosen_token['attention_mask']
|
||||
})
|
||||
chosen = [
|
||||
data["prompt"] + data["chosen"] + self.end_token
|
||||
for data in tqdm(dataset, disable=not is_rank_0())
|
||||
]
|
||||
chosen_token = tokenizer(chosen,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.chosen = {
|
||||
"input_ids": chosen_token["input_ids"],
|
||||
"attention_mask": chosen_token["attention_mask"]
|
||||
}
|
||||
|
||||
reject = prompt + data['rejected'] + self.end_token
|
||||
reject_token = tokenizer(reject,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.reject.append({
|
||||
"input_ids": reject_token['input_ids'],
|
||||
"attention_mask": reject_token['attention_mask']
|
||||
})
|
||||
reject = [
|
||||
data["prompt"] + data["rejected"] + self.end_token
|
||||
for data in tqdm(dataset, disable=not is_rank_0())
|
||||
]
|
||||
reject_token = tokenizer(reject,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.reject = {
|
||||
"input_ids": reject_token["input_ids"],
|
||||
"attention_mask": reject_token["attention_mask"]
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
length = len(self.chosen)
|
||||
length = self.chosen["input_ids"].shape[0]
|
||||
return length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
|
||||
"input_ids"], self.reject[idx]["attention_mask"]
|
||||
return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
|
||||
self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
|
||||
|
||||
|
||||
# Anthropic/hh-rlhf
|
||||
|
@ -74,39 +74,41 @@ class HhRlhfDataset(Dataset):
|
|||
|
||||
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
|
||||
super().__init__()
|
||||
self.chosen = []
|
||||
self.reject = []
|
||||
if special_token is None:
|
||||
self.end_token = tokenizer.eos_token
|
||||
else:
|
||||
self.end_token = special_token
|
||||
for data in tqdm(dataset, disable=not is_rank_0()):
|
||||
chosen = data['chosen'] + self.end_token
|
||||
chosen_token = tokenizer(chosen,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.chosen.append({
|
||||
"input_ids": chosen_token['input_ids'],
|
||||
"attention_mask": chosen_token['attention_mask']
|
||||
})
|
||||
self.end_token = tokenizer.eos_token \
|
||||
if special_token is None else special_token
|
||||
|
||||
reject = data['rejected'] + self.end_token
|
||||
reject_token = tokenizer(reject,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.reject.append({
|
||||
"input_ids": reject_token['input_ids'],
|
||||
"attention_mask": reject_token['attention_mask']
|
||||
})
|
||||
chosen = [
|
||||
data["chosen"] + self.end_token
|
||||
for data in tqdm(dataset, disable=not is_rank_0())
|
||||
]
|
||||
chosen_token = tokenizer(chosen,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.chosen = {
|
||||
"input_ids": chosen_token["input_ids"],
|
||||
"attention_mask": chosen_token["attention_mask"]
|
||||
}
|
||||
|
||||
reject = [
|
||||
data["rejected"] + self.end_token
|
||||
for data in tqdm(dataset, disable=not is_rank_0())
|
||||
]
|
||||
reject_token = tokenizer(reject,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.reject = {
|
||||
"input_ids": reject_token["input_ids"],
|
||||
"attention_mask": reject_token["attention_mask"]
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
length = len(self.chosen)
|
||||
length = self.chosen["input_ids"].shape[0]
|
||||
return length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
|
||||
"input_ids"], self.reject[idx]["attention_mask"]
|
||||
return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
|
||||
self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
|
||||
|
|
|
@ -13,44 +13,64 @@
|
|||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, List, Sequence, Tuple
|
||||
from typing import Dict, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .conversation import default_conversation
|
||||
from .utils import is_rank_0, jload
|
||||
|
||||
# The following is a template prompt for a 4-round conversation.
|
||||
"""
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>
|
||||
"""
|
||||
# Please note that we only calculate loss on assistant's answer tokens.
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
DEFAULT_EOS_TOKEN = "</s>"
|
||||
PROMPT_DICT = {
|
||||
"prompt_input":
|
||||
("Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
|
||||
"prompt_input": ("Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
|
||||
"prompt_no_input": ("Below is an instruction that describes a task. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
"### Instruction:\n{instruction}\n\n### Response:"),
|
||||
}
|
||||
|
||||
|
||||
def _preprocess(sources: Sequence[str],
|
||||
targets: Sequence[str],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_length: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Preprocess the data by tokenizing."""
|
||||
sequences = [s + t for s, t in zip(sources, targets)]
|
||||
sequences_token = tokenizer(sequences,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
sources_token = tokenizer(sources,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
|
||||
labels = copy.deepcopy(sequences_token["input_ids"])
|
||||
for i in range(labels.shape[0]):
|
||||
source_len = sources_token["attention_mask"][i].sum().item()
|
||||
pad_len = max_length - sequences_token["attention_mask"][i].sum().item()
|
||||
if tokenizer.padding_side == "right":
|
||||
# |prompt|completion|eos|pad|
|
||||
labels[i][:source_len] = IGNORE_INDEX
|
||||
elif tokenizer.padding_side == "left":
|
||||
# |pad|prompt|completion|eos|
|
||||
labels[i][pad_len:pad_len + source_len] = IGNORE_INDEX
|
||||
else:
|
||||
raise RuntimeError()
|
||||
|
||||
return sequences_token["input_ids"], labels, sequences_token["attention_mask"]
|
||||
|
||||
|
||||
class SFTDataset(Dataset):
|
||||
"""
|
||||
Dataset for sft model
|
||||
|
@ -61,115 +81,31 @@ class SFTDataset(Dataset):
|
|||
max_length: max length of input
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
|
||||
def __init__(self,
|
||||
dataset: Dict,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_length: int = 512
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.input_ids = []
|
||||
|
||||
for data in tqdm(dataset, disable=not is_rank_0()):
|
||||
prompt = data['prompt'] + data['completion'] + tokenizer.eos_token
|
||||
prompt_token = tokenizer(prompt,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
sources = [data["prompt"] for data in dataset]
|
||||
targets = [
|
||||
data["completion"] + tokenizer.eos_token
|
||||
for data in tqdm(dataset, disable=not is_rank_0())
|
||||
]
|
||||
|
||||
self.input_ids.append(prompt_token['input_ids'][0])
|
||||
self.labels = copy.deepcopy(self.input_ids)
|
||||
self.input_ids, self.labels, self.attention_mask = \
|
||||
_preprocess(sources, targets, tokenizer, max_length)
|
||||
|
||||
def __len__(self):
|
||||
length = len(self.input_ids)
|
||||
length = self.input_ids.shape[0]
|
||||
return length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
|
||||
|
||||
|
||||
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer,
|
||||
max_length: int) -> Dict[str, torch.Tensor]:
|
||||
"""Tokenize a list of strings."""
|
||||
tokenized_list = tokenizer(strings, return_tensors="pt", padding="longest", max_length=max_length, truncation=True)
|
||||
input_ids = labels = tokenized_list["input_ids"]
|
||||
input_ids_lens = labels_lens = \
|
||||
tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
|
||||
return dict(
|
||||
input_ids=input_ids,
|
||||
labels=labels,
|
||||
input_ids_lens=input_ids_lens,
|
||||
labels_lens=labels_lens,
|
||||
)
|
||||
|
||||
|
||||
def preprocess(
|
||||
sources: Sequence[str],
|
||||
targets: Sequence[str],
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
max_length: int,
|
||||
) -> Dict:
|
||||
"""Preprocess the data by tokenizing."""
|
||||
examples = [s + t for s, t in zip(sources, targets)]
|
||||
examples_tokenized, sources_tokenized = [
|
||||
_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
|
||||
]
|
||||
input_ids = examples_tokenized["input_ids"]
|
||||
labels = copy.deepcopy(input_ids)
|
||||
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
|
||||
label[:source_len] = IGNORE_INDEX
|
||||
return dict(input_ids=input_ids, labels=labels)
|
||||
|
||||
|
||||
def preprocess_conversation(sources: List[List[Dict]], tokenizer: transformers.PreTrainedTokenizer,
|
||||
max_length: int) -> Dict:
|
||||
"""Preprocess the conversation data by tokenizing."""
|
||||
conversations = []
|
||||
intermediates = []
|
||||
for source in sources:
|
||||
header = f"{default_conversation.system}"
|
||||
conversation, intermediate = _add_speaker_and_signal(header, source)
|
||||
conversations.append(conversation)
|
||||
intermediates.append(intermediate)
|
||||
|
||||
conversations_tokenized = _tokenize_fn(conversations, tokenizer, max_length)
|
||||
input_ids = conversations_tokenized["input_ids"]
|
||||
targets = copy.deepcopy(input_ids)
|
||||
|
||||
assert len(targets) == len(intermediates)
|
||||
for target, inters in zip(targets, intermediates):
|
||||
mask = torch.zeros_like(target, dtype=torch.bool)
|
||||
for inter in inters:
|
||||
tokenized = _tokenize_fn(inter, tokenizer, max_length)
|
||||
|
||||
start_idx = tokenized["input_ids"][0].size(0) - 1
|
||||
end_idx = tokenized["input_ids"][1].size(0)
|
||||
|
||||
mask[start_idx:end_idx] = True
|
||||
target[~mask] = IGNORE_INDEX
|
||||
|
||||
return dict(input_ids=input_ids, labels=targets)
|
||||
|
||||
|
||||
def _add_speaker_and_signal(header: str,
|
||||
source: List[Dict],
|
||||
get_conversation: bool = True) -> Tuple[str, List[List[str]]]:
|
||||
END_SIGNAL = DEFAULT_EOS_TOKEN
|
||||
conversation = header
|
||||
intermediate = []
|
||||
for sentence in source:
|
||||
from_str = sentence["from"]
|
||||
if from_str.lower() == "human":
|
||||
from_str = default_conversation.roles[0]
|
||||
elif from_str.lower() == "gpt":
|
||||
from_str = default_conversation.roles[1]
|
||||
else:
|
||||
from_str = 'unknown'
|
||||
|
||||
value = from_str + ": " + sentence["value"] + END_SIGNAL
|
||||
if sentence["from"].lower() == "gpt":
|
||||
start = conversation + from_str + ": "
|
||||
end = conversation + value
|
||||
intermediate.append([start, end])
|
||||
if get_conversation:
|
||||
conversation += value
|
||||
return conversation, intermediate
|
||||
return dict(input_ids=self.input_ids[idx],
|
||||
labels=self.labels[idx],
|
||||
attention_mask=self.attention_mask[idx])
|
||||
|
||||
|
||||
class SupervisedDataset(Dataset):
|
||||
|
@ -177,10 +113,10 @@ class SupervisedDataset(Dataset):
|
|||
|
||||
def __init__(self,
|
||||
data_path: str,
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_datasets_size: int = None,
|
||||
max_length: int = 512):
|
||||
super(SupervisedDataset, self).__init__()
|
||||
super().__init__()
|
||||
logger.info("Loading data...")
|
||||
list_data_dict = jload(data_path)
|
||||
logger.info(f"Loaded {len(list_data_dict)} examples.")
|
||||
|
@ -190,52 +126,25 @@ class SupervisedDataset(Dataset):
|
|||
list_data_dict = list_data_dict[:max_datasets_size]
|
||||
|
||||
logger.info("Formatting inputs...")
|
||||
if "conversations" not in list_data_dict[0]:
|
||||
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
|
||||
sources = [
|
||||
prompt_input.format_map(example)
|
||||
if example.get("input", "") != "" else prompt_no_input.format_map(example) for example in list_data_dict
|
||||
]
|
||||
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
|
||||
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
|
||||
sources = [
|
||||
prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example)
|
||||
for example in list_data_dict
|
||||
]
|
||||
targets = [
|
||||
example['output'] + tokenizer.eos_token
|
||||
for example in list_data_dict
|
||||
]
|
||||
|
||||
if is_rank_0():
|
||||
logger.info("Tokenizing inputs... This may take some time...")
|
||||
|
||||
data_dict = preprocess(sources, targets, tokenizer, max_length)
|
||||
else:
|
||||
if is_rank_0():
|
||||
logger.info("Tokenizing inputs... This may take some time...")
|
||||
|
||||
sources = [conv["conversations"] for conv in list_data_dict]
|
||||
data_dict = preprocess_conversation(sources, tokenizer, max_length)
|
||||
|
||||
if is_rank_0():
|
||||
logger.info("Tokenizing finish.")
|
||||
|
||||
self.input_ids = data_dict["input_ids"]
|
||||
self.labels = data_dict["labels"]
|
||||
logger.info("Tokenizing inputs... This may take some time...")
|
||||
self.input_ids, self.labels, self.attention_mask = \
|
||||
_preprocess(sources, targets, tokenizer, max_length)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.input_ids)
|
||||
length = self.input_ids.shape[0]
|
||||
return length
|
||||
|
||||
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
||||
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSupervisedDataset(object):
|
||||
"""Collate examples for supervised fine-tuning."""
|
||||
|
||||
tokenizer: transformers.PreTrainedTokenizer
|
||||
|
||||
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
||||
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids,
|
||||
batch_first=True,
|
||||
padding_value=self.tokenizer.pad_token_id)
|
||||
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
|
||||
return dict(
|
||||
input_ids=input_ids,
|
||||
labels=labels,
|
||||
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
||||
)
|
||||
def __getitem__(self, idx):
|
||||
return dict(input_ids=self.input_ids[idx],
|
||||
labels=self.labels[idx],
|
||||
attention_mask=self.attention_mask[idx])
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
from .base import ExperienceBuffer
|
||||
from .naive import NaiveExperienceBuffer
|
||||
|
||||
__all__ = ['ExperienceBuffer', 'NaiveExperienceBuffer']
|
|
@ -4,8 +4,8 @@ from typing import Any
|
|||
from coati.experience_maker.base import Experience
|
||||
|
||||
|
||||
class ReplayBuffer(ABC):
|
||||
"""Replay buffer base class. It stores experience.
|
||||
class ExperienceBuffer(ABC):
|
||||
"""Experience buffer base class. It stores experience.
|
||||
|
||||
Args:
|
||||
sample_batch_size (int): Batch size when sampling.
|
|
@ -4,12 +4,12 @@ from typing import List
|
|||
import torch
|
||||
from coati.experience_maker.base import Experience
|
||||
|
||||
from .base import ReplayBuffer
|
||||
from .base import ExperienceBuffer
|
||||
from .utils import BufferItem, make_experience_batch, split_experience_batch
|
||||
|
||||
|
||||
class NaiveReplayBuffer(ReplayBuffer):
|
||||
"""Naive replay buffer class. It stores experience.
|
||||
class NaiveExperienceBuffer(ExperienceBuffer):
|
||||
"""Naive experience buffer class. It stores experience.
|
||||
|
||||
Args:
|
||||
sample_batch_size (int): Batch size when sampling.
|
|
@ -33,7 +33,8 @@ class BufferItem:
|
|||
def split_experience_batch(experience: Experience) -> List[BufferItem]:
|
||||
batch_size = experience.sequences.size(0)
|
||||
batch_kwargs = [{} for _ in range(batch_size)]
|
||||
keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
|
||||
keys = ('sequences', 'action_log_probs', 'values',
|
||||
'reward', 'advantages', 'attention_mask', 'action_mask')
|
||||
for key in keys:
|
||||
value = getattr(experience, key)
|
||||
if isinstance(value, torch.Tensor):
|
||||
|
@ -48,7 +49,7 @@ def split_experience_batch(experience: Experience) -> List[BufferItem]:
|
|||
return items
|
||||
|
||||
|
||||
def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
|
||||
def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
|
||||
assert side in ('left', 'right')
|
||||
max_len = max(seq.size(0) for seq in sequences)
|
||||
padded_sequences = []
|
||||
|
@ -62,11 +63,12 @@ def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> tor
|
|||
def make_experience_batch(items: List[BufferItem]) -> Experience:
|
||||
kwargs = {}
|
||||
to_pad_keys = set(('action_log_probs', 'action_mask'))
|
||||
keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
|
||||
keys = ('sequences', 'action_log_probs', 'values',
|
||||
'reward', 'advantages', 'attention_mask', 'action_mask')
|
||||
for key in keys:
|
||||
vals = [getattr(item, key) for item in items]
|
||||
if key in to_pad_keys:
|
||||
batch_data = zero_pad_sequences(vals)
|
||||
batch_data = _zero_pad_sequences(vals)
|
||||
else:
|
||||
batch_data = torch.stack(vals, dim=0)
|
||||
kwargs[key] = batch_data
|
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
from coati.models.generation import generate_with_actor
|
||||
from coati.models.utils import calc_action_log_probs, compute_reward, normalize
|
||||
import torch.nn.functional as F
|
||||
from coati.models.generation import generate
|
||||
from coati.models.utils import calc_action_log_probs, compute_reward
|
||||
|
||||
from .base import Experience, ExperienceMaker
|
||||
|
||||
|
@ -17,10 +18,27 @@ class NaiveExperienceMaker(ExperienceMaker):
|
|||
self.initial_model.eval()
|
||||
self.reward_model.eval()
|
||||
|
||||
sequences, attention_mask, action_mask = generate_with_actor(self.actor,
|
||||
input_ids,
|
||||
return_action_mask=True,
|
||||
**generate_kwargs)
|
||||
# generate sequences
|
||||
sequences = generate(self.actor, input_ids, **generate_kwargs)
|
||||
|
||||
# calculate auxiliary tensors
|
||||
attention_mask = None
|
||||
pad_token_id = generate_kwargs.get('pad_token_id', None)
|
||||
if pad_token_id is not None:
|
||||
attention_mask = sequences.not_equal(pad_token_id)\
|
||||
.to(dtype=torch.long, device=sequences.device)
|
||||
|
||||
input_len = input_ids.size(1)
|
||||
eos_token_id = generate_kwargs.get('eos_token_id', None)
|
||||
if eos_token_id is None:
|
||||
action_mask = torch.ones_like(sequences, dtype=torch.bool)
|
||||
else:
|
||||
# left padding may be applied, only mask action
|
||||
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
|
||||
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
|
||||
action_mask[:, :input_len] = False
|
||||
action_mask = action_mask[:, 1:]
|
||||
action_mask = action_mask[:, -(sequences.size(1) - input_len):]
|
||||
num_actions = action_mask.size(1)
|
||||
|
||||
actor_output = self.actor(sequences, attention_mask)
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from .base import Actor, Critic, RewardModel
|
||||
from .lora import LoRAModule, convert_to_lora_module
|
||||
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
|
||||
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
|
||||
|
||||
__all__ = [
|
||||
'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss',
|
||||
'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'LogSigLoss', 'LogExpLoss',
|
||||
'LoRAModule', 'convert_to_lora_module'
|
||||
]
|
||||
|
|
|
@ -14,7 +14,6 @@ class BLOOMCritic(Critic):
|
|||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (BloomConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
@ -22,7 +21,6 @@ class BLOOMCritic(Critic):
|
|||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
|
@ -32,7 +30,6 @@ class BLOOMCritic(Critic):
|
|||
model = BloomModel(config)
|
||||
else:
|
||||
model = BloomModel(BloomConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
||||
|
|
|
@ -13,7 +13,6 @@ class BLOOMRM(RewardModel):
|
|||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (BloomConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
@ -21,7 +20,6 @@ class BLOOMRM(RewardModel):
|
|||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
|
@ -30,8 +28,7 @@ class BLOOMRM(RewardModel):
|
|||
model = BloomModel(config)
|
||||
else:
|
||||
model = BloomModel(BloomConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base import Actor
|
||||
|
||||
try:
|
||||
from transformers.generation_logits_process import (
|
||||
|
@ -16,9 +16,9 @@ except ImportError:
|
|||
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
|
||||
|
||||
|
||||
def prepare_logits_processor(top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None) -> LogitsProcessorList:
|
||||
def _prepare_logits_processor(top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None) -> LogitsProcessorList:
|
||||
processor_list = LogitsProcessorList()
|
||||
if temperature is not None and temperature != 1.0:
|
||||
processor_list.append(TemperatureLogitsWarper(temperature))
|
||||
|
@ -37,22 +37,22 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
|
|||
return unfinished_sequences.max() == 0
|
||||
|
||||
|
||||
def sample(model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
max_length: int,
|
||||
early_stopping: bool = False,
|
||||
eos_token_id: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
|
||||
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
|
||||
**model_kwargs) -> torch.Tensor:
|
||||
def _sample(model: Actor,
|
||||
input_ids: torch.Tensor,
|
||||
max_length: int,
|
||||
early_stopping: bool = False,
|
||||
eos_token_id: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
|
||||
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
|
||||
**model_kwargs) -> torch.Tensor:
|
||||
if input_ids.size(1) >= max_length:
|
||||
return input_ids
|
||||
|
||||
logits_processor = prepare_logits_processor(top_k, top_p, temperature)
|
||||
logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
|
||||
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
||||
|
||||
for _ in range(input_ids.size(1), max_length):
|
||||
|
@ -89,7 +89,8 @@ def sample(model: nn.Module,
|
|||
return input_ids
|
||||
|
||||
|
||||
def generate(model: nn.Module,
|
||||
@torch.no_grad()
|
||||
def generate(model: Actor,
|
||||
input_ids: torch.Tensor,
|
||||
max_length: int,
|
||||
num_beams: int = 1,
|
||||
|
@ -128,51 +129,19 @@ def generate(model: nn.Module,
|
|||
raise NotImplementedError
|
||||
elif is_sample_gen_mode:
|
||||
# run sample
|
||||
return sample(model,
|
||||
input_ids,
|
||||
max_length,
|
||||
early_stopping=early_stopping,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
prepare_inputs_fn=prepare_inputs_fn,
|
||||
update_model_kwargs_fn=update_model_kwargs_fn,
|
||||
**model_kwargs)
|
||||
return _sample(model,
|
||||
input_ids,
|
||||
max_length,
|
||||
early_stopping=early_stopping,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
prepare_inputs_fn=prepare_inputs_fn,
|
||||
update_model_kwargs_fn=update_model_kwargs_fn,
|
||||
**model_kwargs)
|
||||
elif is_beam_gen_mode:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise ValueError("Unsupported generation mode")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_with_actor(
|
||||
actor_model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
return_action_mask: bool = True,
|
||||
**kwargs
|
||||
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
|
||||
"""Generate token sequence with actor model. Refer to `generate` for more details.
|
||||
"""
|
||||
# generate sequences
|
||||
sequences = generate(actor_model, input_ids, **kwargs)
|
||||
|
||||
# calculate auxiliary tensors
|
||||
attention_mask = None
|
||||
pad_token_id = kwargs.get('pad_token_id', None)
|
||||
if pad_token_id is not None:
|
||||
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
|
||||
if not return_action_mask:
|
||||
return sequences, attention_mask, None
|
||||
input_len = input_ids.size(1)
|
||||
eos_token_id = kwargs.get('eos_token_id', None)
|
||||
if eos_token_id is None:
|
||||
action_mask = torch.ones_like(sequences, dtype=torch.bool)
|
||||
else:
|
||||
# left padding may be applied, only mask action
|
||||
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
|
||||
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
|
||||
action_mask[:, :input_len] = False
|
||||
action_mask = action_mask[:, 1:]
|
||||
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
|
||||
|
|
|
@ -14,7 +14,6 @@ class GPTCritic(Critic):
|
|||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (GPT2Config): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the LO-RA decomposition.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
@ -22,7 +21,6 @@ class GPTCritic(Critic):
|
|||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[GPT2Config] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
|
@ -32,7 +30,6 @@ class GPTCritic(Critic):
|
|||
model = GPT2Model(config)
|
||||
else:
|
||||
model = GPT2Model(GPT2Config())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.n_embd, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
||||
|
|
|
@ -14,7 +14,6 @@ class GPTRM(RewardModel):
|
|||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (GPT2Config): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
@ -22,7 +21,6 @@ class GPTRM(RewardModel):
|
|||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[GPT2Config] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
|
@ -31,8 +29,6 @@ class GPTRM(RewardModel):
|
|||
model = GPT2Model(config)
|
||||
else:
|
||||
model = GPT2Model(GPT2Config())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.n_embd, 1)
|
||||
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1))
|
||||
|
|
|
@ -13,7 +13,6 @@ class LlamaCritic(Critic):
|
|||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (LlamaConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
@ -21,7 +20,6 @@ class LlamaCritic(Critic):
|
|||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[LlamaConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
|
@ -33,9 +31,5 @@ class LlamaCritic(Critic):
|
|||
else:
|
||||
model = LlamaModel(LlamaConfig())
|
||||
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
||||
|
|
|
@ -13,7 +13,6 @@ class LlamaRM(RewardModel):
|
|||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (LlamaConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
@ -21,7 +20,6 @@ class LlamaRM(RewardModel):
|
|||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[LlamaConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
|
||||
|
@ -32,8 +30,6 @@ class LlamaRM(RewardModel):
|
|||
else:
|
||||
model = LlamaModel(LlamaConfig())
|
||||
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
|
||||
|
||||
|
|
|
@ -98,18 +98,18 @@ class LoraLinear(lora.LoRALayer, nn.Module):
|
|||
return F.linear(x, T(self.weight), bias=self.bias)
|
||||
|
||||
|
||||
def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
|
||||
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
|
||||
assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
|
||||
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
|
||||
return lora_linear
|
||||
|
||||
|
||||
def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
|
||||
def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
|
||||
for name, child in module.named_children():
|
||||
if isinstance(child, nn.Linear):
|
||||
setattr(module, name, lora_linear_wrapper(child, lora_rank))
|
||||
setattr(module, name, _lora_linear_wrapper(child, lora_rank))
|
||||
else:
|
||||
convert_to_lora_recursively(child, lora_rank)
|
||||
_convert_to_lora_recursively(child, lora_rank)
|
||||
|
||||
|
||||
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
|
||||
|
@ -124,7 +124,7 @@ def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: s
|
|||
"""
|
||||
if lora_rank <= 0:
|
||||
return module
|
||||
convert_to_lora_recursively(module, lora_rank)
|
||||
_convert_to_lora_recursively(module, lora_rank)
|
||||
lora.mark_only_lora_as_trainable(module, lora_train_bias)
|
||||
return module
|
||||
|
||||
|
|
|
@ -68,31 +68,6 @@ class ValueLoss(nn.Module):
|
|||
return 0.5 * loss
|
||||
|
||||
|
||||
class PPOPtxActorLoss(nn.Module):
|
||||
"""
|
||||
To Do:
|
||||
|
||||
PPO-ptx Actor Loss
|
||||
"""
|
||||
|
||||
def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None:
|
||||
super().__init__()
|
||||
self.pretrain_coef = pretrain_coef
|
||||
self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
|
||||
self.pretrain_loss_fn = pretrain_loss_fn
|
||||
|
||||
def forward(self,
|
||||
log_probs: torch.Tensor,
|
||||
old_log_probs: torch.Tensor,
|
||||
advantages: torch.Tensor,
|
||||
lm_logits: torch.Tensor,
|
||||
lm_input_ids: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
|
||||
lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
|
||||
return policy_loss + self.pretrain_coef * lm_loss
|
||||
|
||||
|
||||
class LogSigLoss(nn.Module):
|
||||
"""
|
||||
Pairwise Loss for Reward Model
|
||||
|
|
|
@ -14,7 +14,6 @@ class OPTCritic(Critic):
|
|||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (OPTConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
@ -22,7 +21,6 @@ class OPTCritic(Critic):
|
|||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[OPTConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
|
@ -32,7 +30,6 @@ class OPTCritic(Critic):
|
|||
model = OPTModel(config)
|
||||
else:
|
||||
model = OPTModel(OPTConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
||||
|
|
|
@ -13,7 +13,6 @@ class OPTRM(RewardModel):
|
|||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (OPTConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
@ -21,7 +20,6 @@ class OPTRM(RewardModel):
|
|||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[OPTConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
|
@ -30,8 +28,6 @@ class OPTRM(RewardModel):
|
|||
model = OPTModel(config)
|
||||
else:
|
||||
model = OPTModel(OPTConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
|
||||
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.word_embed_proj_dim + 1))
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
from typing import Optional, Union
|
||||
|
||||
import loralib as lora
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def compute_approx_kl(log_probs: torch.Tensor,
|
||||
log_probs_base: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def _compute_approx_kl(log_probs: torch.Tensor,
|
||||
log_probs_base: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""
|
||||
Compute the approximate KL divergence between two distributions.
|
||||
Schulman blog: http://joschu.net/blog/kl-approx.html
|
||||
|
@ -35,12 +33,12 @@ def compute_reward(r: Union[torch.Tensor, float],
|
|||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if kl_coef <= 0.0:
|
||||
return r
|
||||
kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
|
||||
kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
|
||||
reward = r - kl_coef * kl
|
||||
return reward
|
||||
|
||||
|
||||
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
||||
def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
||||
log_probs = F.log_softmax(logits, dim=-1)
|
||||
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
|
||||
return log_probs_labels.squeeze(-1)
|
||||
|
@ -58,7 +56,7 @@ def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num
|
|||
torch.Tensor: Action log probs.
|
||||
"""
|
||||
logits = output['logits']
|
||||
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
|
||||
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
|
||||
return log_probs[:, -num_actions:]
|
||||
|
||||
|
||||
|
@ -68,41 +66,3 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
|
|||
mask_sum = mask.sum(dim=dim)
|
||||
mean = tensor / (mask_sum + 1e-8)
|
||||
return mean
|
||||
|
||||
|
||||
def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
|
||||
tensor = tensor * mask
|
||||
mean = masked_mean(tensor, mask, dim=dim)
|
||||
mean_centered = tensor - mean
|
||||
var = masked_mean(mean_centered**2, mask, dim=dim)
|
||||
return mean_centered * var.clamp(min=eps).rsqrt()
|
||||
|
||||
|
||||
def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor:
|
||||
mean = tensor.mean(dim)
|
||||
mean_centered = tensor - mean
|
||||
var = (mean_centered**2).mean(dim)
|
||||
norm = mean_centered * var.clamp(min=eps).rsqrt()
|
||||
return norm
|
||||
|
||||
|
||||
def convert_to_lora(model: nn.Module,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
lora_rank: int = 16,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.,
|
||||
fan_in_fan_out: bool = False,
|
||||
merge_weights: bool = True):
|
||||
if lora_rank > min(input_size, output_size):
|
||||
raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}")
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
module._modules[name] = lora.Linear(input_size,
|
||||
output_size,
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
fan_in_fan_out=fan_in_fan_out,
|
||||
merge_weights=merge_weights)
|
||||
|
|
|
@ -115,12 +115,12 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
|
|||
avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||
|
||||
print_rank_0(
|
||||
'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
|
||||
f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n' +
|
||||
f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
|
||||
f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
+
|
||||
f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
|
||||
+ f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n'
|
||||
+ f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
|
||||
+ f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
|
||||
+ f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
)
|
||||
|
||||
|
||||
|
@ -204,9 +204,9 @@ class TrainerPerformanceEvaluator(TrainerCallback):
|
|||
avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||
|
||||
print_rank_0(
|
||||
'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
|
||||
f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
|
||||
f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
+
|
||||
f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
|
||||
+ f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
|
||||
+ f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
|
||||
+ f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
)
|
||||
|
|
|
@ -6,9 +6,9 @@ from typing import Any, List
|
|||
|
||||
import ray
|
||||
import torch
|
||||
from coati.experience_buffer import ExperienceBuffer
|
||||
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
|
||||
from coati.experience_maker.base import Experience
|
||||
from coati.replay_buffer import ReplayBuffer
|
||||
from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
|
||||
# from torch.multiprocessing import Queue
|
||||
from ray.util.queue import Queue
|
||||
|
||||
|
|
|
@ -4,8 +4,8 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
|||
|
||||
import ray
|
||||
import torch
|
||||
from coati.experience_buffer.utils import BufferItem
|
||||
from coati.experience_maker import Experience
|
||||
from coati.replay_buffer.utils import BufferItem
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
|
|
|
@ -8,9 +8,9 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
|||
import ray
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
|
||||
from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
|
||||
from coati.models.base import Actor, Critic, RewardModel
|
||||
from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
|
||||
from coati.trainer.callbacks import Callback
|
||||
from coati.trainer.strategies import Strategy
|
||||
from coati.trainer.strategies.sampler import DistributedSampler
|
||||
|
@ -19,13 +19,9 @@ from torch import Tensor
|
|||
from tqdm import tqdm
|
||||
|
||||
from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
|
||||
from .utils import (get_model_numel,
|
||||
get_rank,
|
||||
get_world_size,
|
||||
is_rank_0,
|
||||
set_dist_env,
|
||||
state_dict_to)
|
||||
from .lora_constructor import LoRAConstructor
|
||||
from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env, state_dict_to
|
||||
|
||||
|
||||
@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
|
||||
class ExperienceMakerHolder:
|
||||
|
@ -41,7 +37,7 @@ class ExperienceMakerHolder:
|
|||
self,
|
||||
detached_trainer_name_list: List[str],
|
||||
strategy_fn: Callable[[], Strategy],
|
||||
# a function returns (actor, critic, reward_model, initial_model)
|
||||
# a function returns (actor, critic, reward_model, initial_model)
|
||||
model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
|
||||
env_info: Dict[str, str] = None,
|
||||
sync_models_from_trainers: bool = False,
|
||||
|
@ -205,15 +201,19 @@ class ExperienceMakerHolder:
|
|||
self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
|
||||
else:
|
||||
new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
|
||||
state_dict_increase = self.actor_lora_constructor.reconstruct_increase(new_actor_state_dict, new_actor_lora_config_dict)
|
||||
self.actor_lora_constructor.load_state_dict_increase(self.experience_maker.actor.model, state_dict_increase)
|
||||
state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
|
||||
new_actor_state_dict, new_actor_lora_config_dict)
|
||||
self.actor_lora_constructor.load_state_dict_increase(
|
||||
self.experience_maker.actor.model, state_dict_increase)
|
||||
if new_critic_state_dict is not None:
|
||||
if not self._update_lora_weights or fully_update:
|
||||
self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
|
||||
else:
|
||||
new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
|
||||
state_dict_increase = self.critic_lora_constructor.reconstruct_increase(new_critic_state_dict, new_critic_lora_config_dict)
|
||||
self.critic_lora_constructor.load_state_dict_increase(self.experience_maker.critic, state_dict_increase)
|
||||
state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
|
||||
new_critic_state_dict, new_critic_lora_config_dict)
|
||||
self.critic_lora_constructor.load_state_dict_increase(
|
||||
self.experience_maker.critic, state_dict_increase)
|
||||
|
||||
# the lock must be released after both actor and critic being updated
|
||||
if chunk_end:
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
from typing import Any, Callable, Dict, List, Optional
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from loralib.layers import LoRALayer
|
||||
from coati.models.lora import LoraLinear
|
||||
from loralib.layers import LoRALayer
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -23,19 +23,19 @@ class LoRAConstructor:
|
|||
Usage:
|
||||
Step 1 (Sender):
|
||||
filter_state_dict_lora()
|
||||
|
||||
|
||||
Step 2 (Sender, Optional):
|
||||
extract_lora_config()
|
||||
|
||||
|
||||
Step 3 (Sender):
|
||||
send state_dict_lora and lora_config_dict
|
||||
|
||||
|
||||
Step 4 (Receiver):
|
||||
reconstruct_increase()
|
||||
|
||||
|
||||
Step 5 (Receiver):
|
||||
load_state_dict_increase()
|
||||
|
||||
|
||||
'''
|
||||
|
||||
def __init__(self):
|
||||
|
|
|
@ -1,4 +0,0 @@
|
|||
from .base import ReplayBuffer
|
||||
from .naive import NaiveReplayBuffer
|
||||
|
||||
__all__ = ['ReplayBuffer', 'NaiveReplayBuffer']
|
|
@ -4,8 +4,8 @@ from typing import List
|
|||
|
||||
import torch.nn as nn
|
||||
import tqdm
|
||||
from coati.experience_buffer import NaiveExperienceBuffer
|
||||
from coati.experience_maker import Experience
|
||||
from coati.replay_buffer import NaiveReplayBuffer
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
@ -62,7 +62,7 @@ class OnPolicyTrainer(ABC):
|
|||
|
||||
Args:
|
||||
strategy (Strategy):the strategy to use for training
|
||||
buffer (NaiveReplayBuffer): the buffer to collect experiences
|
||||
data_buffer (NaiveExperienceBuffer): the buffer to collect experiences
|
||||
sample_buffer (bool, defaults to False): whether to sample from buffer
|
||||
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
|
||||
callbacks (List[Callback], defaults to []): the callbacks to call during training process
|
||||
|
@ -70,13 +70,13 @@ class OnPolicyTrainer(ABC):
|
|||
|
||||
def __init__(self,
|
||||
strategy: Strategy,
|
||||
buffer: NaiveReplayBuffer,
|
||||
data_buffer: NaiveExperienceBuffer,
|
||||
sample_buffer: bool,
|
||||
dataloader_pin_memory: bool,
|
||||
callbacks: List[Callback] = []) -> None:
|
||||
super().__init__()
|
||||
self.strategy = strategy
|
||||
self.buffer = buffer
|
||||
self.data_buffer = data_buffer
|
||||
self.sample_buffer = sample_buffer
|
||||
self.dataloader_pin_memory = dataloader_pin_memory
|
||||
self.callbacks = callbacks
|
||||
|
@ -144,7 +144,7 @@ class OnPolicyTrainer(ABC):
|
|||
self._on_make_experience_start()
|
||||
experience = self._make_experience(collect_step)
|
||||
self._on_make_experience_end(experience)
|
||||
self.buffer.append(experience)
|
||||
self.data_buffer.append(experience)
|
||||
|
||||
def _update_phase(self, update_step: int):
|
||||
self._on_learn_epoch_start(update_step)
|
||||
|
@ -181,8 +181,8 @@ class OnPolicyTrainer(ABC):
|
|||
# HACK(cwher): according to the design of boost API, dataloader should also be boosted,
|
||||
# but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
|
||||
# I only call strategy.setup_dataloader() to setup dataloader.
|
||||
self.dataloader = self.strategy.setup_dataloader(self.buffer, self.dataloader_pin_memory)
|
||||
self.dataloader = self.strategy.setup_dataloader(self.data_buffer, self.dataloader_pin_memory)
|
||||
for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
|
||||
self._update_phase(update_step)
|
||||
# NOTE: this is for on-policy algorithms
|
||||
self.buffer.clear()
|
||||
self.data_buffer.clear()
|
||||
|
|
|
@ -171,13 +171,13 @@ class PerformanceEvaluator(Callback):
|
|||
learn_time_per_sample = divide(avg_learn_duration, num_effective_samples)
|
||||
|
||||
print_rank_0(
|
||||
f'Performance summary:\n' +
|
||||
f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n'
|
||||
+
|
||||
f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n'
|
||||
+ f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n' +
|
||||
f'Overall time per sample: {overall_time_per_sample:.2f} s\n' +
|
||||
f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n'
|
||||
+
|
||||
f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
|
||||
f'Performance summary:\n'
|
||||
+ f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n'
|
||||
|
||||
+ f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n'
|
||||
+ f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n'
|
||||
+ f'Overall time per sample: {overall_time_per_sample:.2f} s\n'
|
||||
+ f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n'
|
||||
|
||||
+ f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
|
||||
)
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
from typing import Dict, List
|
||||
|
||||
import torch.nn as nn
|
||||
from coati.experience_buffer import NaiveExperienceBuffer
|
||||
from coati.experience_maker import Experience, NaiveExperienceMaker
|
||||
from coati.models.base import Actor, Critic, get_base_model
|
||||
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
|
||||
from coati.models.utils import calc_action_log_probs
|
||||
from coati.replay_buffer import NaiveReplayBuffer
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
|
@ -86,9 +86,9 @@ class PPOTrainer(OnPolicyTrainer):
|
|||
assert not offload_inference_models, \
|
||||
"GeminiPlugin is not compatible with manual model.to('cpu')"
|
||||
|
||||
buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
|
||||
data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
|
||||
super().__init__(
|
||||
strategy, buffer,
|
||||
strategy, data_buffer,
|
||||
sample_buffer, dataloader_pin_memory,
|
||||
callbacks
|
||||
)
|
||||
|
@ -170,7 +170,7 @@ class PPOTrainer(OnPolicyTrainer):
|
|||
|
||||
# buffer may be empty at first, we should rebuild at each training
|
||||
if self.sample_buffer:
|
||||
experience = self.buffer.sample()
|
||||
experience = self.data_buffer.sample()
|
||||
self._on_learn_batch_start()
|
||||
experience.to_device(self.device)
|
||||
metrics = self._training_step(experience)
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from coati.replay_buffer import ReplayBuffer
|
||||
from coati.experience_buffer import ExperienceBuffer
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
||||
|
@ -45,7 +45,7 @@ class Strategy(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
|
||||
def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
|
||||
pass
|
||||
|
||||
def model_init_context(self):
|
||||
|
|
|
@ -4,7 +4,6 @@ from typing import Optional
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
|
||||
|
@ -44,7 +43,7 @@ class LowLevelZeroStrategy(DDPStrategy):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
stage: int = 3,
|
||||
stage: int = 2,
|
||||
precision: str = 'fp16',
|
||||
seed: int = 42,
|
||||
placement_policy: str = 'cuda',
|
||||
|
@ -214,14 +213,3 @@ class GeminiStrategy(DDPStrategy):
|
|||
ddp_model = model.unwrap()
|
||||
assert isinstance(ddp_model, GeminiDDP)
|
||||
return ddp_model.module
|
||||
|
||||
def save_pretrained(self,
|
||||
model: nn.Module,
|
||||
path: str,
|
||||
only_rank0: bool = True,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
|
||||
raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')
|
||||
|
||||
def get_model_state_dict_shard(self, model: nn.Module, **config):
|
||||
assert isinstance(self.plugin, GeminiPlugin)
|
||||
yield from super().get_model_state_dict_shard(model, **config)
|
||||
|
|
|
@ -7,7 +7,8 @@ import numpy as np
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from coati.replay_buffer import ReplayBuffer
|
||||
from coati.experience_buffer import ExperienceBuffer
|
||||
from coati.models import Actor, Critic, RewardModel
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
||||
|
@ -71,13 +72,13 @@ class DDPStrategy(Strategy):
|
|||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
|
||||
return self.plugin.prepare_dataloader(replay_buffer,
|
||||
batch_size=replay_buffer.sample_batch_size,
|
||||
def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
|
||||
return self.plugin.prepare_dataloader(data_buffer,
|
||||
batch_size=data_buffer.sample_batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
pin_memory=pin_memory,
|
||||
collate_fn=replay_buffer.collate_fn)
|
||||
collate_fn=data_buffer.collate_fn)
|
||||
|
||||
def setup_sampler(self, dataset) -> DistributedSampler:
|
||||
# FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
|
||||
|
@ -92,13 +93,33 @@ class DDPStrategy(Strategy):
|
|||
path: str,
|
||||
only_rank0: bool = True,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
|
||||
if only_rank0 and dist.get_rank() != 0:
|
||||
return
|
||||
unwrapped_model = self.unwrap_model(model)
|
||||
assert isinstance(unwrapped_model, PreTrainedModel)
|
||||
unwrapped_model.save_pretrained(path)
|
||||
if tokenizer is not None:
|
||||
tokenizer.save_pretrained(path)
|
||||
if not only_rank0 or dist.get_rank() == 0:
|
||||
unwrapped_model = self.unwrap_model(model)
|
||||
assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
|
||||
pretrained_model = unwrapped_model.model
|
||||
assert isinstance(pretrained_model, PreTrainedModel)
|
||||
# HACK: only use hf save_pretrained to save config
|
||||
pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None)
|
||||
if tokenizer is not None:
|
||||
tokenizer.save_pretrained(path)
|
||||
model_path = os.path.join(path, "pytorch_model.bin")
|
||||
self.save_model(model,
|
||||
model_path,
|
||||
only_rank0=only_rank0)
|
||||
|
||||
def _replace_keys(model_path: str,
|
||||
replace_fn: Callable):
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
state_dict = {
|
||||
replace_fn(k): v
|
||||
for k, v in state_dict.items()
|
||||
}
|
||||
torch.save(state_dict, model_path)
|
||||
|
||||
# FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
|
||||
# HACK: rename keys of pytorch_model.bin
|
||||
if dist.get_rank() == 0:
|
||||
_replace_keys(model_path, lambda k: k.replace("model.", "", 1))
|
||||
|
||||
def get_model_state_dict_shard(self, model: nn.Module, **config):
|
||||
# TODO: implement sharding on naive strategy
|
||||
|
|
|
@ -27,7 +27,6 @@ class DistributedSampler:
|
|||
assert len(indices) == self.num_samples
|
||||
self.indices = indices
|
||||
|
||||
|
||||
def sample(self, batch_size: int) -> list:
|
||||
sampled_indices = np.random.choice(self.indices, batch_size, replace=False)
|
||||
return [self.dataset[idx] for idx in sampled_indices]
|
||||
|
|
|
@ -21,9 +21,13 @@ class CycledDataLoader:
|
|||
self.dataloader = dataloader
|
||||
|
||||
self.count = 0
|
||||
self.dataloader_iter = iter(dataloader)
|
||||
self.dataloader_iter = None
|
||||
|
||||
def next(self):
|
||||
# defer initialization
|
||||
if self.dataloader_iter is None:
|
||||
self.dataloader_iter = iter(self.dataloader)
|
||||
|
||||
self.count += 1
|
||||
try:
|
||||
return next(self.dataloader_iter)
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
import argparse
|
||||
import dataclasses
|
||||
import os
|
||||
import parser
|
||||
from typing import List
|
||||
|
||||
import tqdm
|
||||
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
|
||||
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
|
||||
from coati.models.opt import OPTRM, OPTActor, OPTCritic
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from transformers import AutoConfig, AutoTokenizer, BloomConfig, BloomTokenizerFast, GPT2Config, GPT2Tokenizer
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class HFRepoFiles:
|
||||
repo_id: str
|
||||
files: List[str]
|
||||
|
||||
def download(self, dir_path: str):
|
||||
for file in self.files:
|
||||
file_path = hf_hub_download(self.repo_id, file, local_dir=dir_path)
|
||||
|
||||
def download_all(self):
|
||||
file_path = snapshot_download(self.repo_id)
|
||||
|
||||
|
||||
def test_init(model: str, dir_path: str):
|
||||
if model == "gpt2":
|
||||
config = GPT2Config.from_pretrained(dir_path)
|
||||
actor = GPTActor(config=config)
|
||||
critic = GPTCritic(config=config)
|
||||
reward_model = GPTRM(config=config)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(dir_path)
|
||||
elif model == "bloom":
|
||||
config = BloomConfig.from_pretrained(dir_path)
|
||||
actor = BLOOMActor(config=config)
|
||||
critic = BLOOMCritic(config=config)
|
||||
reward_model = BLOOMRM(config=config)
|
||||
tokenizer = BloomTokenizerFast.from_pretrained(dir_path)
|
||||
elif model == "opt":
|
||||
config = AutoConfig.from_pretrained(dir_path)
|
||||
actor = OPTActor(config=config)
|
||||
critic = OPTCritic(config=config)
|
||||
reward_model = OPTRM(config=config)
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_path)
|
||||
else:
|
||||
raise NotImplementedError(f"Model {model} not implemented")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model-dir", type=str, default="test_models")
|
||||
parser.add_argument("--config-only", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.exists(args.model_dir):
|
||||
print(f"[INFO]: {args.model_dir} already exists")
|
||||
exit(0)
|
||||
|
||||
repo_list = {
|
||||
"gpt2": HFRepoFiles(
|
||||
repo_id="gpt2",
|
||||
files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]
|
||||
),
|
||||
"bloom": HFRepoFiles(
|
||||
repo_id="bigscience/bloom-560m",
|
||||
files=["config.json", "tokenizer.json", "tokenizer_config.json"]
|
||||
),
|
||||
"opt": HFRepoFiles(
|
||||
repo_id="facebook/opt-350m",
|
||||
files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"]
|
||||
),
|
||||
}
|
||||
|
||||
os.mkdir(args.model_dir)
|
||||
for model_name in tqdm.tqdm(repo_list):
|
||||
dir_path = os.path.join(args.model_dir, model_name)
|
||||
if args.config_only:
|
||||
os.mkdir(dir_path)
|
||||
repo_list[model_name].download(dir_path)
|
||||
else:
|
||||
repo_list[model_name].download_all()
|
||||
test_init(model_name, dir_path)
|
|
@ -1,7 +1,6 @@
|
|||
import argparse
|
||||
|
||||
import random
|
||||
import json
|
||||
import random
|
||||
|
||||
random.seed(42)
|
||||
|
||||
|
@ -10,8 +9,10 @@ def sample(args):
|
|||
with open(args.dataset_path, mode='r') as f:
|
||||
dataset_list = json.load(f)
|
||||
|
||||
sampled_dataset = [{"instruction": sample["instruction"], "id":idx}
|
||||
for idx, sample in enumerate(random.sample(dataset_list, args.sample_size))]
|
||||
sampled_dataset = [
|
||||
{"instruction": sample["instruction"], "id": idx}
|
||||
for idx, sample in enumerate(random.sample(dataset_list, args.sample_size))
|
||||
]
|
||||
|
||||
with open(args.save_path, mode='w') as f:
|
||||
json.dump(sampled_dataset, f, indent=4,
|
||||
|
|
|
@ -4,40 +4,50 @@ import torch
|
|||
from coati.models.bloom import BLOOMActor
|
||||
from coati.models.generation import generate
|
||||
from coati.models.gpt import GPTActor
|
||||
from coati.models.llama import LlamaActor
|
||||
from coati.models.opt import OPTActor
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
|
||||
|
||||
|
||||
def eval(args):
|
||||
# configure model
|
||||
if args.model == 'gpt2':
|
||||
actor = GPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
|
||||
actor = GPTActor(pretrained=args.pretrain)
|
||||
elif args.model == 'bloom':
|
||||
actor = BLOOMActor(pretrained=args.pretrain).to(torch.cuda.current_device())
|
||||
actor = BLOOMActor(pretrained=args.pretrain)
|
||||
elif args.model == 'opt':
|
||||
actor = OPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
|
||||
actor = OPTActor(pretrained=args.pretrain)
|
||||
elif args.model == 'llama':
|
||||
actor = LlamaActor(pretrained=args.pretrain)
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
state_dict = torch.load(args.model_path)
|
||||
actor.load_state_dict(state_dict)
|
||||
actor.to(torch.cuda.current_device())
|
||||
if args.model_path is not None:
|
||||
state_dict = torch.load(args.model_path)
|
||||
actor.load_state_dict(state_dict)
|
||||
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'bloom':
|
||||
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
|
||||
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'opt':
|
||||
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'llama':
|
||||
tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
tokenizer.eos_token = '<\s>'
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
actor.eval()
|
||||
input = args.input
|
||||
input_ids = tokenizer.encode(input, return_tensors='pt').to(torch.cuda.current_device())
|
||||
input_ids = tokenizer.encode(args.input,
|
||||
return_tensors='pt')\
|
||||
.to(torch.cuda.current_device())
|
||||
outputs = generate(actor,
|
||||
input_ids,
|
||||
max_length=args.max_length,
|
||||
|
@ -45,13 +55,14 @@ def eval(args):
|
|||
top_k=50,
|
||||
top_p=0.95,
|
||||
num_return_sequences=1)
|
||||
output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
|
||||
print(output)
|
||||
output = tokenizer.batch_decode(outputs[0],
|
||||
skip_special_tokens=True)
|
||||
print(f"[Output]: {''.join(output)}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
|
||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
|
||||
# We suggest to use the pretrained model from HuggingFace, use pretrain to configure model
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--model_path', type=str, default=None)
|
||||
|
|
|
@ -1,160 +0,0 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 4
|
||||
|
||||
set -xue
|
||||
|
||||
if [ -z "$SFT_DATASET" ]; then
|
||||
echo "Please set \$SFT_DATASET to the path to sft dataset."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "$PROMPT_PATH" ]; then
|
||||
echo "Please set \$PROMPT_PATH to the path to prompts csv."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "$PRETRAIN_DATASET" ]; then
|
||||
echo "Please set \$PRETRAIN_DATASET to the path to alpaca data."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
BASE=$(realpath $(dirname $0))
|
||||
|
||||
export OMP_NUM_THREADS=8
|
||||
|
||||
# install requirements
|
||||
pip install -r ${BASE}/requirements.txt
|
||||
|
||||
wandb init -m offline
|
||||
|
||||
# FIXME: This is a hack to skip tests that are not working
|
||||
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
|
||||
# - llama-*: These tests can be passed locally, skipped for long execution time
|
||||
SKIPPED_TESTS=(
|
||||
"gpt2-ddp"
|
||||
"llama-ddp"
|
||||
"llama-colossalai_gemini"
|
||||
"llama-colossalai_zero2"
|
||||
)
|
||||
|
||||
# These tests are quick and do not have any dependencies
|
||||
for model in 'gpt2' 'bloom' 'opt' 'llama'; do
|
||||
for strategy in 'ddp' 'colossalai_gemini' 'colossalai_zero2'; do
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " ${model}-${strategy} " ]]; then
|
||||
echo "[Test]: Skipped $model-$strategy"
|
||||
continue
|
||||
fi
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
|
||||
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
|
||||
--strategy $strategy --model $model \
|
||||
--num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \
|
||||
--train_batch_size 2 --lora_rank 4
|
||||
done
|
||||
done
|
||||
|
||||
# train sft
|
||||
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'bigscience/bloom-560m' \
|
||||
--model 'bloom' --strategy colossalai_zero2 --lora_rank 4 \
|
||||
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
|
||||
--save_path ${BASE}/output
|
||||
rm -rf ${BASE}/output
|
||||
|
||||
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
|
||||
--model 'gpt2' --strategy colossalai_zero2 \
|
||||
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
|
||||
--save_path ${BASE}/output
|
||||
rm -rf ${BASE}/output
|
||||
|
||||
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \
|
||||
--model 'opt' --strategy colossalai_zero2 --lora_rank 4 \
|
||||
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
|
||||
--save_path ${BASE}/output
|
||||
rm -rf ${BASE}/output
|
||||
|
||||
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
|
||||
--model 'gpt2' --strategy ddp --lora_rank 4 \
|
||||
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
|
||||
--save_path ${BASE}/output
|
||||
rm -rf ${BASE}/output
|
||||
|
||||
# train rm
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||
--pretrain 'facebook/opt-350m' --model 'opt' \
|
||||
--strategy colossalai_zero2 --loss_fn 'log_sig' \
|
||||
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
|
||||
--test True --lora_rank 0 \
|
||||
--save_path ${BASE}/rm_ckpt_opt.pt
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||
--pretrain 'gpt2' --model 'gpt2' \
|
||||
--strategy colossalai_zero2 --loss_fn 'log_exp' \
|
||||
--dataset 'Dahoas/rm-static' \
|
||||
--test True --lora_rank 0 \
|
||||
--save_path ${BASE}/rm_ckpt_gpt.pt
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||
--pretrain 'gpt2' --model 'gpt2' \
|
||||
--strategy ddp --loss_fn 'log_exp' \
|
||||
--dataset 'Dahoas/rm-static' \
|
||||
--test True --lora_rank 4 \
|
||||
--save_path ${BASE}/rm_ckpt.pt
|
||||
rm -rf ${BASE}/rm_ckpt.pt
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||
--pretrain 'bigscience/bloom-560m' --model 'bloom' \
|
||||
--strategy colossalai_zero2 --loss_fn 'log_sig' \
|
||||
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
|
||||
--test True --lora_rank 4 \
|
||||
--save_path ${BASE}/rm_ckpt.pt
|
||||
rm -rf ${BASE}/rm_ckpt.pt
|
||||
|
||||
# train rl
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
|
||||
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
|
||||
--strategy colossalai_zero2 --num_episodes 1 \
|
||||
--num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
|
||||
--pretrain 'facebook/opt-350m' --model opt \
|
||||
--rm_pretrain 'facebook/opt-350m' \
|
||||
--rm_path ${BASE}/rm_ckpt_opt.pt \
|
||||
--save_path ${BASE}/actor_checkpoint_prompts.pt
|
||||
rm -rf ${BASE}/rm_ckpt_opt.pt
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
|
||||
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
|
||||
--strategy colossalai_zero2 --num_episodes 1 \
|
||||
--num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
|
||||
--pretrain 'gpt2' --model gpt2 \
|
||||
--rm_pretrain 'gpt2' \
|
||||
--rm_path ${BASE}/rm_ckpt_gpt.pt \
|
||||
--save_path ${BASE}/actor_checkpoint_prompts.pt
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
|
||||
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
|
||||
--strategy colossalai_gemini --num_episodes 1 \
|
||||
--num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
|
||||
--pretrain 'gpt2' --model gpt2 \
|
||||
--rm_pretrain 'gpt2' \
|
||||
--rm_path ${BASE}/rm_ckpt_gpt.pt \
|
||||
--save_path ${BASE}/actor_checkpoint_prompts.pt
|
||||
rm -rf ${BASE}/rm_ckpt_gpt.pt
|
||||
|
||||
rm -rf ${BASE}/actor_checkpoint_prompts.pt
|
||||
|
||||
# 3080 doesn't support P2P, skip this test
|
||||
# cd ${BASE}/ray && bash test_ci.sh && cd ${BASE}
|
|
@ -1,8 +1,9 @@
|
|||
import argparse
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
|
||||
from coati.dataset import PromptDataset, SupervisedDataset
|
||||
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
|
||||
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
|
||||
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
|
||||
|
@ -29,6 +30,7 @@ def main(args):
|
|||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
||||
|
||||
if args.rm_path is not None:
|
||||
warnings.warn('LoRA weights should be merged with the model weights')
|
||||
state_dict = torch.load(args.rm_path, map_location='cpu')
|
||||
|
||||
with strategy.model_init_context():
|
||||
|
@ -50,18 +52,18 @@ def main(args):
|
|||
rm_model_name = args.rm_model
|
||||
|
||||
if rm_model_name == 'gpt2':
|
||||
reward_model = GPTRM(pretrained=args.rm_pretrain)
|
||||
reward_model = GPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
|
||||
elif rm_model_name == 'bloom':
|
||||
reward_model = BLOOMRM(pretrained=args.rm_pretrain)
|
||||
reward_model = BLOOMRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
|
||||
elif rm_model_name == 'opt':
|
||||
reward_model = OPTRM(pretrained=args.rm_pretrain)
|
||||
reward_model = OPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
|
||||
elif rm_model_name == 'llama':
|
||||
reward_model = LlamaRM(pretrained=args.rm_pretrain)
|
||||
reward_model = LlamaRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
|
||||
else:
|
||||
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
|
||||
|
||||
if args.rm_path is not None:
|
||||
reward_model.load_state_dict(state_dict)
|
||||
reward_model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
initial_model.to(torch.float16).to(torch.cuda.current_device())
|
||||
reward_model.to(torch.float16).to(torch.cuda.current_device())
|
||||
|
@ -89,7 +91,7 @@ def main(args):
|
|||
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
|
||||
|
||||
if args.rm_path is not None:
|
||||
critic.load_state_dict(state_dict)
|
||||
critic.load_state_dict(state_dict, strict=False)
|
||||
del state_dict
|
||||
|
||||
if args.strategy != 'colossalai_gemini':
|
||||
|
@ -106,23 +108,25 @@ def main(args):
|
|||
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(
|
||||
'gpt2' if args.tokenizer is None else args.tokenizer)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'bloom':
|
||||
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
|
||||
tokenizer = BloomTokenizerFast.from_pretrained(
|
||||
'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'opt':
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'llama':
|
||||
tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
|
||||
tokenizer.eos_token = '<\s>'
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
||||
|
||||
prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_dataset, max_datasets_size=16384)
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
|
||||
|
@ -144,8 +148,7 @@ def main(args):
|
|||
pretrain_dataloader = DataLoader(pretrain_dataset,
|
||||
shuffle=(pretrain_sampler is None),
|
||||
sampler=pretrain_sampler,
|
||||
batch_size=args.ptx_batch_size,
|
||||
collate_fn=data_collator)
|
||||
batch_size=args.ptx_batch_size)
|
||||
|
||||
# NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized.
|
||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = \
|
||||
|
@ -197,6 +200,7 @@ if __name__ == '__main__':
|
|||
default='colossalai_zero2',
|
||||
help='strategy to use')
|
||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
|
||||
parser.add_argument('--tokenizer', type=str, default=None)
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
|
||||
parser.add_argument('--rm_path', type=str, default=None)
|
||||
|
|
|
@ -36,34 +36,39 @@ def train(args):
|
|||
# configure model
|
||||
with strategy.model_init_context():
|
||||
if args.model == 'bloom':
|
||||
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
|
||||
elif args.model == 'opt':
|
||||
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
|
||||
elif args.model == 'gpt2':
|
||||
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
|
||||
elif args.model == 'llama':
|
||||
model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
model.to(torch.float16).to(torch.cuda.current_device())
|
||||
|
||||
if args.model_path is not None:
|
||||
state_dict = torch.load(args.model_path)
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
model = model.to(torch.float16)
|
||||
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(
|
||||
'gpt2' if args.tokenizer is None else args.tokenizer)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'bloom':
|
||||
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
|
||||
tokenizer = BloomTokenizerFast.from_pretrained(
|
||||
'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'opt':
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'llama':
|
||||
tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
|
||||
tokenizer.eos_token = '<\s>'
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
@ -89,8 +94,8 @@ def train(args):
|
|||
data = load_dataset(args.dataset)
|
||||
|
||||
if args.test:
|
||||
train_data = data['train'].select(range(100))
|
||||
eval_data = data['test'].select(range(10))
|
||||
train_data = data['train'].select(range(20))
|
||||
eval_data = data['test'].select(range(5))
|
||||
else:
|
||||
train_data = data['train']
|
||||
eval_data = data['test']
|
||||
|
@ -177,6 +182,7 @@ if __name__ == '__main__':
|
|||
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||
default='colossalai_zero2')
|
||||
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
|
||||
parser.add_argument('--tokenizer', type=str, default=None)
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--model_path', type=str, default=None)
|
||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
||||
|
@ -184,7 +190,7 @@ if __name__ == '__main__':
|
|||
type=str,
|
||||
choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
|
||||
default='Dahoas/rm-static')
|
||||
parser.add_argument('--subset', type=str, default=None)
|
||||
parser.add_argument('--subset', type=lambda x: None if x == 'None' else x, default=None)
|
||||
parser.add_argument('--save_path', type=str, default='rm_ckpt')
|
||||
parser.add_argument('--max_epochs', type=int, default=1)
|
||||
parser.add_argument('--batch_size', type=int, default=1)
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
|
||||
| tail -n +2 \
|
||||
| nl -v 0 \
|
||||
| tee /dev/tty \
|
||||
| sort -g -k 2 \
|
||||
| awk '{print $1}' \
|
||||
| head -n $n)
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
|
@ -16,9 +16,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
|||
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 train_reward_model.py \
|
||||
--pretrain <your pretrain path> \
|
||||
--model 'bloom' \
|
||||
--strategy colossalai_zero2 \
|
||||
--loss_fn 'log_sig'\
|
||||
--save_path <your model saving path>\
|
||||
--dataset 'Anthropic/hh-rlhf'\
|
||||
--model 'bloom' \
|
||||
--strategy colossalai_zero2 \
|
||||
--loss_fn 'log_sig' \
|
||||
--dataset 'Anthropic/hh-rlhf'
|
||||
|
|
|
@ -1,24 +1,22 @@
|
|||
import argparse
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import loralib as lora
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
|
||||
from coati.models import convert_to_lora_module
|
||||
from coati.dataset import SFTDataset, SupervisedDataset
|
||||
from coati.models.bloom import BLOOMActor
|
||||
from coati.models.gpt import GPTActor
|
||||
from coati.models.llama import LlamaActor
|
||||
from coati.models.opt import OPTActor
|
||||
from coati.trainer import SFTTrainer
|
||||
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
|
||||
from datasets import load_dataset
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from transformers import AutoTokenizer, BloomConfig, BloomForCausalLM, BloomTokenizerFast, LlamaConfig, LlamaForCausalLM
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
from transformers.models.opt.configuration_opt import OPTConfig
|
||||
from transformers.models.opt.modeling_opt import OPTForCausalLM
|
||||
from transformers.trainer import get_scheduler
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
@ -31,8 +29,6 @@ def train(args):
|
|||
if args.strategy == 'ddp':
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == 'colossalai_gemini':
|
||||
raise NotImplementedError(
|
||||
'Gemini is not supported .from_pretrained() yet. We will update this after checkpoint io is ready.')
|
||||
strategy = GeminiStrategy(placement_policy='cuda')
|
||||
elif args.strategy == 'colossalai_zero2':
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
|
||||
|
@ -42,40 +38,49 @@ def train(args):
|
|||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
||||
|
||||
# configure model
|
||||
if args.lora_rank > 0:
|
||||
warnings.warn("Gradient checkpoint is disabled when using LoRA")
|
||||
args.grad_checkpoint = False
|
||||
with strategy.model_init_context():
|
||||
if args.model == 'bloom':
|
||||
model = convert_to_lora_module(BloomForCausalLM.from_pretrained(args.pretrain),
|
||||
args.lora_rank).half().cuda()
|
||||
model = BLOOMActor(pretrained=args.pretrain,
|
||||
lora_rank=args.lora_rank,
|
||||
checkpoint=args.grad_checkpoint)
|
||||
elif args.model == 'opt':
|
||||
model = convert_to_lora_module(OPTForCausalLM.from_pretrained(args.pretrain), args.lora_rank).half().cuda()
|
||||
model = OPTActor(pretrained=args.pretrain,
|
||||
lora_rank=args.lora_rank,
|
||||
checkpoint=args.grad_checkpoint)
|
||||
elif args.model == 'gpt2':
|
||||
model = convert_to_lora_module(GPT2LMHeadModel.from_pretrained(args.pretrain), args.lora_rank).half().cuda()
|
||||
model = GPTActor(pretrained=args.pretrain,
|
||||
lora_rank=args.lora_rank,
|
||||
checkpoint=args.grad_checkpoint)
|
||||
elif args.model == 'llama':
|
||||
model = convert_to_lora_module(LlamaForCausalLM.from_pretrained(args.pretrain),
|
||||
args.lora_rank).half().cuda()
|
||||
model = LlamaActor(pretrained=args.pretrain,
|
||||
lora_rank=args.lora_rank,
|
||||
checkpoint=args.grad_checkpoint)
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
if args.grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
model.to(torch.float16).to(torch.cuda.current_device())
|
||||
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(
|
||||
'gpt2' if args.tokenizer is None else args.tokenizer)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'bloom':
|
||||
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
|
||||
tokenizer = BloomTokenizerFast.from_pretrained(
|
||||
'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'opt':
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'llama':
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.pretrain,
|
||||
padding_side="right",
|
||||
use_fast=False,
|
||||
)
|
||||
tokenizer.eos_token = '</s>'
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
|
||||
tokenizer.eos_token = '<\s>'
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
|
@ -111,7 +116,6 @@ def train(args):
|
|||
max_datasets_size=args.max_datasets_size,
|
||||
max_length=args.max_len)
|
||||
eval_dataset = None
|
||||
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
||||
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
train_sampler = DistributedSampler(train_dataset,
|
||||
|
@ -135,14 +139,12 @@ def train(args):
|
|||
shuffle=(train_sampler is None),
|
||||
sampler=train_sampler,
|
||||
batch_size=args.batch_size,
|
||||
collate_fn=data_collator,
|
||||
pin_memory=True)
|
||||
if eval_dataset is not None:
|
||||
eval_dataloader = DataLoader(eval_dataset,
|
||||
shuffle=(eval_sampler is None),
|
||||
sampler=eval_sampler,
|
||||
batch_size=args.batch_size,
|
||||
collate_fn=data_collator,
|
||||
pin_memory=True)
|
||||
else:
|
||||
eval_dataloader = None
|
||||
|
@ -184,6 +186,7 @@ if __name__ == '__main__':
|
|||
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
|
||||
default='colossalai_zero2')
|
||||
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
|
||||
parser.add_argument('--tokenizer', type=str, default=None)
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--dataset', type=str, default=None)
|
||||
parser.add_argument('--max_datasets_size', type=int, default=None)
|
||||
|
|
|
@ -1,12 +1,29 @@
|
|||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 4
|
||||
|
||||
torchrun --standalone --nproc_per_node=4 train_sft.py \
|
||||
--pretrain "/path/to/LLaMa-7B/" \
|
||||
--model 'llama' \
|
||||
--strategy colossalai_zero2 \
|
||||
--log_interval 10 \
|
||||
--save_path /path/to/Coati-7B \
|
||||
--save_path /path/to/Coati-7B \
|
||||
--dataset /path/to/data.json \
|
||||
--batch_size 4 \
|
||||
--accumulation_steps 8 \
|
||||
--lr 2e-5 \
|
||||
--max_datasets_size 512 \
|
||||
--max_epochs 1 \
|
||||
--max_epochs 1
|
||||
|
|
|
@ -4,8 +4,8 @@ import argparse
|
|||
from time import time
|
||||
|
||||
import torch
|
||||
from llama_gptq import load_quant
|
||||
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
|
||||
from coati.quant import llama_load_quant, low_resource_init
|
||||
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
|
||||
|
||||
|
||||
def generate_prompt(instruction, input=None):
|
||||
|
@ -106,7 +106,10 @@ if __name__ == "__main__":
|
|||
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
|
||||
|
||||
if args.quant == '4bit':
|
||||
model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size)
|
||||
with low_resource_init():
|
||||
config = LlamaConfig.from_pretrained(args.pretrained)
|
||||
model = LlamaForCausalLM(config)
|
||||
model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
|
||||
model.cuda()
|
||||
else:
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
from .loader import load_quant
|
||||
|
||||
__all__ = [
|
||||
'load_quant',
|
||||
]
|
|
@ -1,41 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
|
||||
from .model_utils import find_layers
|
||||
from .quant import make_quant
|
||||
|
||||
|
||||
def load_quant(pretrained: str, checkpoint: str, wbits: int, groupsize: int):
|
||||
config = LlamaConfig.from_pretrained(pretrained)
|
||||
|
||||
def noop(*args, **kwargs):
|
||||
pass
|
||||
|
||||
torch.nn.init.kaiming_uniform_ = noop
|
||||
torch.nn.init.uniform_ = noop
|
||||
torch.nn.init.normal_ = noop
|
||||
|
||||
torch.set_default_dtype(torch.half)
|
||||
transformers.modeling_utils._init_weights = False
|
||||
torch.set_default_dtype(torch.half)
|
||||
model = LlamaForCausalLM(config)
|
||||
torch.set_default_dtype(torch.float)
|
||||
model = model.eval()
|
||||
layers = find_layers(model)
|
||||
for name in ['lm_head']:
|
||||
if name in layers:
|
||||
del layers[name]
|
||||
make_quant(model, layers, wbits, groupsize)
|
||||
|
||||
print(f'Loading model with {wbits} bits...')
|
||||
if checkpoint.endswith('.safetensors'):
|
||||
from safetensors.torch import load_file as safe_load
|
||||
model.load_state_dict(safe_load(checkpoint))
|
||||
else:
|
||||
model.load_state_dict(torch.load(checkpoint))
|
||||
model.seqlen = 2048
|
||||
print('Done.')
|
||||
|
||||
return model
|
|
@ -1,13 +0,0 @@
|
|||
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
|
||||
if type(module) in layers:
|
||||
return {name: module}
|
||||
res = {}
|
||||
for name1, child in module.named_children():
|
||||
res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
|
||||
return res
|
|
@ -1,283 +0,0 @@
|
|||
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/quant.py
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def quantize(x, scale, zero, maxq):
|
||||
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
|
||||
return scale * (q - zero)
|
||||
|
||||
|
||||
class Quantizer(nn.Module):
|
||||
|
||||
def __init__(self, shape=1):
|
||||
super(Quantizer, self).__init__()
|
||||
self.register_buffer('maxq', torch.tensor(0))
|
||||
self.register_buffer('scale', torch.zeros(shape))
|
||||
self.register_buffer('zero', torch.zeros(shape))
|
||||
|
||||
def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8):
|
||||
self.maxq = torch.tensor(2**bits - 1)
|
||||
self.perchannel = perchannel
|
||||
self.sym = sym
|
||||
self.mse = mse
|
||||
self.norm = norm
|
||||
self.grid = grid
|
||||
self.maxshrink = maxshrink
|
||||
|
||||
def find_params(self, x, weight=False):
|
||||
dev = x.device
|
||||
self.maxq = self.maxq.to(dev)
|
||||
|
||||
shape = x.shape
|
||||
if self.perchannel:
|
||||
if weight:
|
||||
x = x.flatten(1)
|
||||
else:
|
||||
if len(shape) == 4:
|
||||
x = x.permute([1, 0, 2, 3])
|
||||
x = x.flatten(1)
|
||||
if len(shape) == 3:
|
||||
x = x.reshape((-1, shape[-1])).t()
|
||||
if len(shape) == 2:
|
||||
x = x.t()
|
||||
else:
|
||||
x = x.flatten().unsqueeze(0)
|
||||
|
||||
tmp = torch.zeros(x.shape[0], device=dev)
|
||||
xmin = torch.minimum(x.min(1)[0], tmp)
|
||||
xmax = torch.maximum(x.max(1)[0], tmp)
|
||||
|
||||
if self.sym:
|
||||
xmax = torch.maximum(torch.abs(xmin), xmax)
|
||||
tmp = xmin < 0
|
||||
if torch.any(tmp):
|
||||
xmin[tmp] = -xmax[tmp]
|
||||
tmp = (xmin == 0) & (xmax == 0)
|
||||
xmin[tmp] = -1
|
||||
xmax[tmp] = +1
|
||||
|
||||
self.scale = (xmax - xmin) / self.maxq
|
||||
if self.sym:
|
||||
self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
|
||||
else:
|
||||
self.zero = torch.round(-xmin / self.scale)
|
||||
|
||||
if self.mse:
|
||||
best = torch.full([x.shape[0]], float('inf'), device=dev)
|
||||
for i in range(int(self.maxshrink * self.grid)):
|
||||
p = 1 - i / self.grid
|
||||
xmin1 = p * xmin
|
||||
xmax1 = p * xmax
|
||||
scale1 = (xmax1 - xmin1) / self.maxq
|
||||
zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
|
||||
q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
|
||||
q -= x
|
||||
q.abs_()
|
||||
q.pow_(self.norm)
|
||||
err = torch.sum(q, 1)
|
||||
tmp = err < best
|
||||
if torch.any(tmp):
|
||||
best[tmp] = err[tmp]
|
||||
self.scale[tmp] = scale1[tmp]
|
||||
self.zero[tmp] = zero1[tmp]
|
||||
if not self.perchannel:
|
||||
if weight:
|
||||
tmp = shape[0]
|
||||
else:
|
||||
tmp = shape[1] if len(shape) != 3 else shape[2]
|
||||
self.scale = self.scale.repeat(tmp)
|
||||
self.zero = self.zero.repeat(tmp)
|
||||
|
||||
if weight:
|
||||
shape = [-1] + [1] * (len(shape) - 1)
|
||||
self.scale = self.scale.reshape(shape)
|
||||
self.zero = self.zero.reshape(shape)
|
||||
return
|
||||
if len(shape) == 4:
|
||||
self.scale = self.scale.reshape((1, -1, 1, 1))
|
||||
self.zero = self.zero.reshape((1, -1, 1, 1))
|
||||
if len(shape) == 3:
|
||||
self.scale = self.scale.reshape((1, 1, -1))
|
||||
self.zero = self.zero.reshape((1, 1, -1))
|
||||
if len(shape) == 2:
|
||||
self.scale = self.scale.unsqueeze(0)
|
||||
self.zero = self.zero.unsqueeze(0)
|
||||
|
||||
def quantize(self, x):
|
||||
if self.ready():
|
||||
return quantize(x, self.scale, self.zero, self.maxq)
|
||||
return x
|
||||
|
||||
def enabled(self):
|
||||
return self.maxq > 0
|
||||
|
||||
def ready(self):
|
||||
return torch.all(self.scale != 0)
|
||||
|
||||
|
||||
try:
|
||||
import quant_cuda
|
||||
except:
|
||||
print('CUDA extension not installed.')
|
||||
|
||||
# Assumes layer is perfectly divisible into 256 * 256 blocks
|
||||
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
|
||||
def __init__(self, bits, groupsize, infeatures, outfeatures):
|
||||
super().__init__()
|
||||
if bits not in [2, 3, 4, 8]:
|
||||
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||
self.infeatures = infeatures
|
||||
self.outfeatures = outfeatures
|
||||
self.bits = bits
|
||||
if groupsize != -1 and groupsize < 32 and groupsize != int(math.pow(2, int(math.log2(groupsize)))):
|
||||
raise NotImplementedError("groupsize supports powers of 2 greater than 32. (e.g. : 32,64,128,etc)")
|
||||
groupsize = groupsize if groupsize != -1 else infeatures
|
||||
self.groupsize = groupsize
|
||||
self.register_buffer(
|
||||
'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)),
|
||||
dtype=torch.int))
|
||||
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
|
||||
self.register_buffer('bias', torch.zeros(outfeatures))
|
||||
self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
|
||||
self._initialized_quant_state = False
|
||||
|
||||
def pack(self, linear, scales, zeros):
|
||||
scales = scales.t().contiguous()
|
||||
zeros = zeros.t().contiguous()
|
||||
scale_zeros = zeros * scales
|
||||
self.scales = scales.clone()
|
||||
if linear.bias is not None:
|
||||
self.bias = linear.bias.clone()
|
||||
|
||||
intweight = []
|
||||
for idx in range(self.infeatures):
|
||||
g_idx = idx // self.groupsize
|
||||
intweight.append(
|
||||
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:,
|
||||
None])
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.numpy().astype(np.uint32)
|
||||
qweight = np.zeros((intweight.shape[0] // 256 * (self.bits * 8), intweight.shape[1]), dtype=np.uint32)
|
||||
i = 0
|
||||
row = 0
|
||||
while row < qweight.shape[0]:
|
||||
if self.bits in [2, 4, 8]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qweight[row] |= intweight[j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
row += 1
|
||||
elif self.bits == 3:
|
||||
for j in range(i, i + 10):
|
||||
qweight[row] |= intweight[j] << (3 * (j - i))
|
||||
i += 10
|
||||
qweight[row] |= intweight[i] << 30
|
||||
row += 1
|
||||
qweight[row] |= (intweight[i] >> 2) & 1
|
||||
i += 1
|
||||
for j in range(i, i + 10):
|
||||
qweight[row] |= intweight[j] << (3 * (j - i) + 1)
|
||||
i += 10
|
||||
qweight[row] |= intweight[i] << 31
|
||||
row += 1
|
||||
qweight[row] |= (intweight[i] >> 1) & 0x3
|
||||
i += 1
|
||||
for j in range(i, i + 10):
|
||||
qweight[row] |= intweight[j] << (3 * (j - i) + 2)
|
||||
i += 10
|
||||
row += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||
|
||||
qweight = qweight.astype(np.int32)
|
||||
self.qweight = torch.from_numpy(qweight)
|
||||
|
||||
zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 256 * (self.bits * 8)), dtype=np.uint32)
|
||||
i = 0
|
||||
col = 0
|
||||
while col < qzeros.shape[1]:
|
||||
if self.bits in [2, 4, 8]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
col += 1
|
||||
elif self.bits == 3:
|
||||
for j in range(i, i + 10):
|
||||
qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
|
||||
i += 10
|
||||
qzeros[:, col] |= zeros[:, i] << 30
|
||||
col += 1
|
||||
qzeros[:, col] |= (zeros[:, i] >> 2) & 1
|
||||
i += 1
|
||||
for j in range(i, i + 10):
|
||||
qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
|
||||
i += 10
|
||||
qzeros[:, col] |= zeros[:, i] << 31
|
||||
col += 1
|
||||
qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
|
||||
i += 1
|
||||
for j in range(i, i + 10):
|
||||
qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
|
||||
i += 10
|
||||
col += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||
|
||||
qzeros = qzeros.astype(np.int32)
|
||||
self.qzeros = torch.from_numpy(qzeros)
|
||||
|
||||
def forward(self, x):
|
||||
intermediate_dtype = torch.float32
|
||||
|
||||
if not self._initialized_quant_state:
|
||||
# Do we even have a bias? Check for at least one non-zero element.
|
||||
if self.bias is not None and bool(torch.any(self.bias != 0)):
|
||||
# Then make sure it's the right type.
|
||||
self.bias.data = self.bias.data.to(intermediate_dtype)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
outshape = list(x.shape)
|
||||
outshape[-1] = self.outfeatures
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if self.bias is None:
|
||||
y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device)
|
||||
else:
|
||||
y = self.bias.clone().repeat(x.shape[0], 1)
|
||||
|
||||
output_dtype = x.dtype
|
||||
x = x.to(intermediate_dtype)
|
||||
if self.bits == 2:
|
||||
quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
|
||||
elif self.bits == 3:
|
||||
quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
|
||||
elif self.bits == 4:
|
||||
quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
|
||||
elif self.bits == 8:
|
||||
quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
|
||||
else:
|
||||
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||
y = y.to(output_dtype)
|
||||
return y.reshape(outshape)
|
||||
|
||||
|
||||
def make_quant(module, names, bits, groupsize, name=''):
|
||||
if isinstance(module, QuantLinear):
|
||||
return
|
||||
for attr in dir(module):
|
||||
tmp = getattr(module, attr)
|
||||
name1 = name + '.' + attr if name != '' else attr
|
||||
if name1 in names:
|
||||
setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features))
|
||||
for name1, child in module.named_children():
|
||||
make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
|
|
@ -5,8 +5,7 @@ from locust import HttpUser, task
|
|||
samples = [[
|
||||
dict(
|
||||
instruction='Who is the best player in the history of NBA?',
|
||||
response=
|
||||
'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
|
||||
response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
|
||||
),
|
||||
dict(instruction='continue this talk', response=''),
|
||||
], [
|
||||
|
|
|
@ -1,19 +1,19 @@
|
|||
import argparse
|
||||
import os
|
||||
from threading import Lock
|
||||
from typing import Dict, Generator, List, Optional
|
||||
from typing import Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from coati.quant import llama_load_quant, low_resource_init
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from llama_gptq import load_quant
|
||||
from pydantic import BaseModel, Field
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.util import get_remote_address
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
|
||||
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
|
||||
from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn
|
||||
|
||||
CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
|
||||
|
@ -56,7 +56,7 @@ app.add_middleware(
|
|||
|
||||
def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
|
||||
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
|
||||
#TODO(ver217): streaming generation does not support repetition_penalty now
|
||||
# TODO(ver217): streaming generation does not support repetition_penalty now
|
||||
model_kwargs = {
|
||||
'max_generate_tokens': max_new_tokens,
|
||||
'early_stopping': True,
|
||||
|
@ -162,7 +162,10 @@ if __name__ == '__main__':
|
|||
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)
|
||||
|
||||
if args.quant == '4bit':
|
||||
model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size)
|
||||
with low_resource_init():
|
||||
config = LlamaConfig.from_pretrained(args.pretrained)
|
||||
model = LlamaForCausalLM(config)
|
||||
model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
|
||||
model.cuda()
|
||||
else:
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
|
|
|
@ -10,37 +10,34 @@ samples = [
|
|||
([
|
||||
Dialogue(
|
||||
instruction='Who is the best player in the history of NBA?',
|
||||
response=
|
||||
'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
|
||||
response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
|
||||
),
|
||||
Dialogue(instruction='continue this talk', response=''),
|
||||
], 128,
|
||||
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
|
||||
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
|
||||
),
|
||||
([
|
||||
Dialogue(
|
||||
instruction='Who is the best player in the history of NBA?',
|
||||
response=
|
||||
'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
|
||||
response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
|
||||
),
|
||||
Dialogue(instruction='continue this talk', response=''),
|
||||
], 200,
|
||||
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
|
||||
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
|
||||
),
|
||||
([
|
||||
Dialogue(
|
||||
instruction='Who is the best player in the history of NBA?',
|
||||
response=
|
||||
'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
|
||||
response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
|
||||
),
|
||||
Dialogue(instruction='continue this talk', response=''),
|
||||
], 211,
|
||||
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n'
|
||||
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n'
|
||||
),
|
||||
([
|
||||
Dialogue(instruction='Who is the best player in the history of NBA?', response=''),
|
||||
], 128,
|
||||
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n'
|
||||
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n'
|
||||
),
|
||||
]
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import json
|
||||
import re
|
||||
from threading import Lock
|
||||
from typing import Any, Callable, Generator, List, Optional
|
||||
import json
|
||||
import jieba
|
||||
|
||||
import jieba
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
@ -127,7 +127,7 @@ STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
|
|||
class ChatPromptProcessor:
|
||||
SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.'
|
||||
|
||||
def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str]=[]):
|
||||
def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []):
|
||||
self.tokenizer = tokenizer
|
||||
self.context = context
|
||||
self.max_len = max_len
|
||||
|
@ -182,6 +182,7 @@ class ChatPromptProcessor:
|
|||
intersection = set(jieba.cut(text.lower())) & self.censored_words
|
||||
return len(intersection) > 0
|
||||
|
||||
|
||||
class LockedIterator:
|
||||
|
||||
def __init__(self, it, lock: Lock) -> None:
|
||||
|
@ -195,6 +196,7 @@ class LockedIterator:
|
|||
with self.lock:
|
||||
return next(self.it)
|
||||
|
||||
|
||||
def load_json(path: str):
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
return json.load(f)
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -xue
|
||||
|
||||
echo "Hint: You can run this script with 'verbose' as the first argument to run all strategies."
|
||||
|
||||
if [[ $# -ne 0 && "$1" == "verbose" ]]; then
|
||||
STRATEGIES=(
|
||||
'ddp'
|
||||
'colossalai_gemini'
|
||||
'colossalai_gemini_cpu'
|
||||
'colossalai_zero2'
|
||||
'colossalai_zero2_cpu'
|
||||
'colossalai_zero1'
|
||||
'colossalai_zero1_cpu'
|
||||
)
|
||||
else
|
||||
STRATEGIES=(
|
||||
'colossalai_zero2'
|
||||
)
|
||||
fi
|
||||
|
||||
BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
|
||||
BENCHMARKS_DIR=$BASE_DIR/benchmarks
|
||||
|
||||
echo "[Test]: testing benchmarks ..."
|
||||
|
||||
for strategy in ${STRATEGIES[@]}; do
|
||||
torchrun --standalone --nproc_per_node 1 $BENCHMARKS_DIR/benchmark_opt_lora_dummy.py \
|
||||
--model 125m --critic_model 125m --strategy ${strategy} --lora_rank 4 \
|
||||
--num_episodes 2 --num_collect_steps 4 --num_update_steps 2 \
|
||||
--train_batch_size 2 --experience_batch_size 4
|
||||
done
|
|
@ -7,7 +7,7 @@ import torch
|
|||
import torch.distributed as dist
|
||||
from coati.models.gpt import GPTActor
|
||||
from coati.models.utils import calc_action_log_probs
|
||||
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
|
||||
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
@ -17,40 +17,41 @@ GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
|
|||
|
||||
|
||||
def get_data(batch_size: int, seq_len: int = 10) -> dict:
|
||||
input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda')
|
||||
input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda")
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
|
||||
def run_test_checkpoint(strategy):
|
||||
BATCH_SIZE = 2
|
||||
def train_step(strategy: Strategy,
|
||||
actor: GPTActor,
|
||||
actor_optim: HybridAdam,
|
||||
batch_size: int = 8):
|
||||
data = get_data(batch_size)
|
||||
action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool)
|
||||
actor_output = actor(data["input_ids"], data["attention_mask"])
|
||||
action_log_probs = calc_action_log_probs(actor_output, data["input_ids"], action_mask.size(1))
|
||||
loss = action_log_probs.sum()
|
||||
strategy.backward(loss, actor, actor_optim)
|
||||
strategy.optimizer_step(actor_optim)
|
||||
|
||||
if strategy == 'ddp':
|
||||
|
||||
def run_test_checkpoint(strategy_name: str,
|
||||
shard: bool):
|
||||
if strategy_name == "ddp":
|
||||
strategy = DDPStrategy()
|
||||
elif strategy == 'colossalai_gemini':
|
||||
strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
|
||||
elif strategy == 'colossalai_zero2':
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
|
||||
elif strategy_name == "colossalai_gemini":
|
||||
strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
|
||||
elif strategy_name == "colossalai_zero2":
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
||||
else:
|
||||
raise ValueError(f'Unsupported strategy "{strategy}"')
|
||||
raise ValueError(f"Unsupported strategy '{strategy_name}'")
|
||||
|
||||
with strategy.model_init_context():
|
||||
actor = GPTActor(config=GPT_CONFIG).cuda()
|
||||
|
||||
actor_optim = HybridAdam(actor.parameters())
|
||||
|
||||
actor, actor_optim = strategy.prepare((actor, actor_optim))
|
||||
|
||||
def run_step():
|
||||
data = get_data(BATCH_SIZE)
|
||||
action_mask = torch.ones_like(data['attention_mask'], dtype=torch.bool)
|
||||
actor_output = actor(data['input_ids'], data['attention_mask'])
|
||||
action_log_probs = calc_action_log_probs(actor_output, data['input_ids'], action_mask.size(1))
|
||||
loss = action_log_probs.sum()
|
||||
strategy.backward(loss, actor, actor_optim)
|
||||
strategy.optimizer_step(actor_optim)
|
||||
|
||||
run_step()
|
||||
train_step(strategy, actor, actor_optim)
|
||||
|
||||
ctx = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
|
||||
|
||||
|
@ -59,43 +60,47 @@ def run_test_checkpoint(strategy):
|
|||
dist.broadcast_object_list(rank0_dirname)
|
||||
rank0_dirname = rank0_dirname[0]
|
||||
|
||||
model_path = os.path.join(rank0_dirname, 'model.pt')
|
||||
strategy.save_model(actor, model_path, only_rank0=True)
|
||||
|
||||
optim_path = os.path.join(rank0_dirname, f'optim.pt')
|
||||
strategy.save_optimizer(actor_optim, optim_path, only_rank0=True)
|
||||
|
||||
# FIXME(cwher): Sharded optimizer checkpoint is not supported yet.
|
||||
# at "ColossalAI/colossalai/checkpoint_io/general_checkpoint_io.py", line 62
|
||||
# optim_path = os.path.join(rank0_dirname, f'optim-r{dist.get_rank()}.pt')
|
||||
# strategy.save_optimizer(actor_optim, optim_path, only_rank0=False)
|
||||
|
||||
model_path = os.path.join(
|
||||
rank0_dirname, "model" if shard else f"model.pt")
|
||||
strategy.save_model(actor, model_path, only_rank0=not shard)
|
||||
optim_path = os.path.join(
|
||||
rank0_dirname, "optim" if shard else "optim.pt")
|
||||
strategy.save_optimizer(actor_optim, optim_path, only_rank0=not shard)
|
||||
dist.barrier()
|
||||
|
||||
strategy.load_model(actor, model_path, strict=False)
|
||||
strategy.load_optimizer(actor_optim, optim_path)
|
||||
|
||||
dist.barrier()
|
||||
|
||||
run_step()
|
||||
train_step(strategy, actor, actor_optim)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, strategy):
|
||||
os.environ['RANK'] = str(rank)
|
||||
os.environ['LOCAL_RANK'] = str(rank)
|
||||
os.environ['WORLD_SIZE'] = str(world_size)
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
os.environ['MASTER_PORT'] = str(port)
|
||||
run_test_checkpoint(strategy)
|
||||
def run_dist(rank: int,
|
||||
world_size: int,
|
||||
port: int,
|
||||
strategy_name: str,
|
||||
shard: bool):
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["LOCAL_RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(port)
|
||||
run_test_checkpoint(strategy_name, shard)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [2])
|
||||
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini'])
|
||||
@pytest.mark.parametrize("world_size", [4])
|
||||
@pytest.mark.parametrize("strategy_name", ["ddp", "colossalai_gemini", "colossalai_zero2"])
|
||||
@pytest.mark.parametrize("shard", [False, True])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_checkpoint(world_size, strategy):
|
||||
spawn(run_dist, world_size, strategy=strategy)
|
||||
def test_checkpoint(world_size: int,
|
||||
strategy_name: str,
|
||||
shard: bool):
|
||||
spawn(run_dist,
|
||||
world_size,
|
||||
strategy_name=strategy_name,
|
||||
shard=shard)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_checkpoint(2, 'colossalai_zero2')
|
||||
if __name__ == "__main__":
|
||||
test_checkpoint(2, "colossalai_gemini", shard=False)
|
||||
|
|
|
@ -0,0 +1,248 @@
|
|||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from coati.dataset.prompt_dataset import PromptDataset
|
||||
from coati.dataset.reward_dataset import HhRlhfDataset, RmStaticDataset
|
||||
from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDataset
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
SFT_DATASET = [
|
||||
{
|
||||
"instruction": "Provide a list of the top 10 most popular mobile games in Asia",
|
||||
"input": "",
|
||||
"output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
|
||||
"id": 0
|
||||
},
|
||||
{
|
||||
"instruction": "Please provide an action plan for reducing carbon footprint on a corporate level",
|
||||
"input": "",
|
||||
"output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
|
||||
"id": 1
|
||||
},
|
||||
{
|
||||
"instruction": "Write a persuasive email to your boss explaining why you should have a pay raise",
|
||||
"input": "",
|
||||
"output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
|
||||
"id": 2
|
||||
},
|
||||
]
|
||||
|
||||
PROMPT_DATASET = [
|
||||
{
|
||||
"instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
|
||||
"id": 0
|
||||
},
|
||||
{
|
||||
"instruction": "Write a descriptive paragraph about a memorable vacation you went on",
|
||||
"id": 1
|
||||
},
|
||||
{
|
||||
"instruction": "Write a persuasive essay arguing why homework should be banned in schools",
|
||||
"id": 2
|
||||
},
|
||||
{
|
||||
"instruction": "Create a chart comparing the statistics on student debt in the United States.",
|
||||
"id": 3
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def make_tokenizer(model: str):
|
||||
if model == "gpt2":
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif model == "bloom":
|
||||
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif model == "opt":
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif model == "llama":
|
||||
tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
else:
|
||||
raise ValueError(f"Unsupported model '{model}'")
|
||||
return tokenizer
|
||||
|
||||
|
||||
def check_content(input_ids_stripped: torch.Tensor,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
model: str):
|
||||
if model == "opt":
|
||||
# NOTE: Contrary to GPT2, OPT adds the EOS token </s> to the beginning of every prompt.
|
||||
assert input_ids_stripped[0] == tokenizer.eos_token_id
|
||||
input_ids_stripped = input_ids_stripped[1:]
|
||||
elif model == "llama":
|
||||
assert input_ids_stripped[0] == tokenizer.bos_token_id
|
||||
input_ids_stripped = input_ids_stripped[1:]
|
||||
|
||||
assert torch.all(input_ids_stripped != tokenizer.pad_token_id)
|
||||
assert torch.all(input_ids_stripped != tokenizer.bos_token_id)
|
||||
assert torch.all(input_ids_stripped != tokenizer.eos_token_id)
|
||||
assert input_ids_stripped != tokenizer.sep_token_id
|
||||
assert input_ids_stripped != tokenizer.cls_token_id
|
||||
assert input_ids_stripped != tokenizer.mask_token_id
|
||||
|
||||
|
||||
@pytest.mark.cpu
|
||||
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
|
||||
@pytest.mark.parametrize("max_length", [32, 1024])
|
||||
@pytest.mark.parametrize("max_datasets_size", [2])
|
||||
def test_prompt_dataset(model: str,
|
||||
max_datasets_size: int,
|
||||
max_length: int):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
dataset_name = "prompt_dataset.json"
|
||||
with open(os.path.join(tmp_dir, dataset_name), "w") as f:
|
||||
json.dump(PROMPT_DATASET, f)
|
||||
tokenizer = make_tokenizer(model)
|
||||
assert tokenizer.padding_side in ("left", "right")
|
||||
prompt_dataset = PromptDataset(data_path=os.path.join(tmp_dir, dataset_name),
|
||||
tokenizer=tokenizer,
|
||||
max_datasets_size=max_datasets_size,
|
||||
max_length=max_length)
|
||||
assert len(prompt_dataset) == min(max_datasets_size, len(PROMPT_DATASET))
|
||||
for i in range(len(prompt_dataset)):
|
||||
assert isinstance(prompt_dataset[i], dict)
|
||||
assert list(prompt_dataset[i].keys()) == ["input_ids", "attention_mask"]
|
||||
input_ids = prompt_dataset[i]["input_ids"]
|
||||
attention_mask = prompt_dataset[i]["attention_mask"]
|
||||
attention_mask = attention_mask.bool()
|
||||
assert input_ids.shape == attention_mask.shape == torch.Size([max_length])
|
||||
assert torch.all(input_ids[torch.logical_not(attention_mask)] == tokenizer.pad_token_id)
|
||||
check_content(input_ids.masked_select(attention_mask), tokenizer, model)
|
||||
|
||||
|
||||
@pytest.mark.cpu
|
||||
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
|
||||
@pytest.mark.parametrize(["dataset_path", "subset"], [
|
||||
("Anthropic/hh-rlhf", "harmless-base"),
|
||||
("Dahoas/rm-static", None)
|
||||
])
|
||||
@pytest.mark.parametrize("max_datasets_size", [32])
|
||||
@pytest.mark.parametrize("max_length", [32, 1024])
|
||||
def test_reward_dataset(model: str,
|
||||
dataset_path: str,
|
||||
subset: Optional[str],
|
||||
max_datasets_size: int,
|
||||
max_length: int):
|
||||
data = load_dataset(dataset_path, data_dir=subset)
|
||||
assert max_datasets_size <= len(data["train"]) \
|
||||
and max_datasets_size <= len(data["test"])
|
||||
train_data = data["train"].select(range(max_datasets_size))
|
||||
test_data = data["test"].select(range(max_datasets_size))
|
||||
tokenizer = make_tokenizer(model)
|
||||
assert tokenizer.padding_side in ("left", "right")
|
||||
|
||||
if dataset_path == "Anthropic/hh-rlhf":
|
||||
train_dataset = HhRlhfDataset(train_data, tokenizer, max_length)
|
||||
test_dataset = HhRlhfDataset(test_data, tokenizer, max_length)
|
||||
elif dataset_path == "Dahoas/rm-static":
|
||||
train_dataset = RmStaticDataset(train_data, tokenizer, max_length)
|
||||
test_dataset = RmStaticDataset(test_data, tokenizer, max_length)
|
||||
else:
|
||||
raise ValueError(f'Unsupported dataset "{dataset_path}"')
|
||||
|
||||
assert len(train_dataset) == len(test_dataset) == max_datasets_size
|
||||
for i in range(max_datasets_size):
|
||||
chosen_ids, c_mask, reject_ids, r_mask = train_dataset[i]
|
||||
assert chosen_ids.shape == c_mask.shape == \
|
||||
reject_ids.shape == r_mask.shape == torch.Size([max_length])
|
||||
c_mask = c_mask.to(torch.bool)
|
||||
r_mask = r_mask.to(torch.bool)
|
||||
if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id:
|
||||
check_content(chosen_ids.masked_select(c_mask)[:-1], tokenizer, model)
|
||||
assert torch.all(chosen_ids.masked_select(torch.logical_not(c_mask)) == tokenizer.pad_token_id)
|
||||
else:
|
||||
check_content(chosen_ids.masked_select(c_mask), tokenizer, model)
|
||||
assert torch.all(c_mask)
|
||||
if reject_ids.masked_select(r_mask)[-1] == tokenizer.eos_token_id:
|
||||
check_content(reject_ids.masked_select(r_mask)[:-1], tokenizer, model)
|
||||
assert torch.all(reject_ids.masked_select(torch.logical_not(r_mask)) == tokenizer.pad_token_id)
|
||||
else:
|
||||
check_content(reject_ids.masked_select(r_mask), tokenizer, model)
|
||||
assert torch.all(r_mask)
|
||||
|
||||
chosen_ids, c_mask, reject_ids, r_mask = test_dataset[i]
|
||||
assert chosen_ids.shape == c_mask.shape == \
|
||||
reject_ids.shape == r_mask.shape == torch.Size([max_length])
|
||||
c_mask = c_mask.to(torch.bool)
|
||||
r_mask = r_mask.to(torch.bool)
|
||||
if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id:
|
||||
check_content(chosen_ids.masked_select(c_mask)[:-1], tokenizer, model)
|
||||
assert torch.all(chosen_ids.masked_select(torch.logical_not(c_mask)) == tokenizer.pad_token_id)
|
||||
else:
|
||||
check_content(chosen_ids.masked_select(c_mask), tokenizer, model)
|
||||
assert torch.all(c_mask)
|
||||
if reject_ids.masked_select(r_mask)[-1] == tokenizer.eos_token_id:
|
||||
check_content(reject_ids.masked_select(r_mask)[:-1], tokenizer, model)
|
||||
assert torch.all(reject_ids.masked_select(torch.logical_not(r_mask)) == tokenizer.pad_token_id)
|
||||
else:
|
||||
check_content(reject_ids.masked_select(r_mask), tokenizer, model)
|
||||
assert torch.all(r_mask)
|
||||
|
||||
|
||||
@pytest.mark.cpu
|
||||
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
|
||||
@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None])
|
||||
@pytest.mark.parametrize("max_dataset_size", [2])
|
||||
@pytest.mark.parametrize("max_length", [32, 1024])
|
||||
def test_sft_dataset(model: str,
|
||||
dataset_path: Optional[str],
|
||||
max_dataset_size: int,
|
||||
max_length: int):
|
||||
tokenizer = make_tokenizer(model)
|
||||
if dataset_path == "yizhongw/self_instruct":
|
||||
data = load_dataset(dataset_path, "super_natural_instructions")
|
||||
train_data = data["train"].select(range(max_dataset_size))
|
||||
sft_dataset = SFTDataset(train_data, tokenizer, max_length)
|
||||
else:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
dataset_name = "sft_dataset.json"
|
||||
with open(os.path.join(tmp_dir, dataset_name), "w") as f:
|
||||
json.dump(SFT_DATASET, f)
|
||||
sft_dataset = SupervisedDataset(tokenizer=tokenizer,
|
||||
data_path=os.path.join(tmp_dir, dataset_name),
|
||||
max_datasets_size=max_dataset_size,
|
||||
max_length=max_length)
|
||||
assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET))
|
||||
|
||||
for i in range(max_dataset_size):
|
||||
assert isinstance(sft_dataset[i], dict)
|
||||
assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"]
|
||||
input_ids = sft_dataset[i]["input_ids"]
|
||||
labels = sft_dataset[i]["labels"]
|
||||
attention_mask = sft_dataset[i]["attention_mask"].to(torch.bool)
|
||||
assert input_ids.shape == labels.shape == \
|
||||
attention_mask.shape == torch.Size([max_length])
|
||||
if input_ids.masked_select(attention_mask)[-1] == tokenizer.eos_token_id:
|
||||
check_content(input_ids.masked_select(attention_mask)[:-1], tokenizer, model)
|
||||
assert torch.all(input_ids.masked_select(torch.logical_not(attention_mask)) == tokenizer.pad_token_id)
|
||||
else:
|
||||
check_content(input_ids.masked_select(attention_mask), tokenizer, model)
|
||||
assert torch.all(attention_mask)
|
||||
ignore_mask = labels == IGNORE_INDEX
|
||||
check_content(input_ids.masked_select(ignore_mask), tokenizer, model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_sft_dataset(model="bloom",
|
||||
dataset_path="yizhongw/self_instruct",
|
||||
max_dataset_size=2,
|
||||
max_length=256)
|
||||
|
||||
test_reward_dataset(model="gpt2",
|
||||
dataset_path="Anthropic/hh-rlhf",
|
||||
subset="harmless-base",
|
||||
max_datasets_size=8,
|
||||
max_length=256)
|
||||
|
||||
test_prompt_dataset(model="opt",
|
||||
max_datasets_size=2,
|
||||
max_length=128)
|
|
@ -4,11 +4,12 @@ from copy import deepcopy
|
|||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coati.experience_buffer import NaiveExperienceBuffer
|
||||
from coati.experience_maker import NaiveExperienceMaker
|
||||
from coati.models.base import RewardModel
|
||||
from coati.models.gpt import GPTActor, GPTCritic
|
||||
from coati.replay_buffer import NaiveReplayBuffer
|
||||
from coati.trainer.strategies import DDPStrategy, GeminiStrategy
|
||||
from coati.trainer.strategies.colossalai import LowLevelZeroStrategy
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
@ -32,13 +33,15 @@ def gather_and_equal(tensor: torch.Tensor) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def run_test_data(strategy):
|
||||
def make_and_consume_experience(strategy):
|
||||
EXPERIENCE_BATCH_SIZE = 4
|
||||
SAMPLE_BATCH_SIZE = 2
|
||||
|
||||
if strategy == 'ddp':
|
||||
strategy = DDPStrategy()
|
||||
elif strategy == 'colossalai':
|
||||
elif strategy == 'colossalai-zero2':
|
||||
strategy = LowLevelZeroStrategy()
|
||||
elif strategy == 'colossalai-gemini':
|
||||
strategy = GeminiStrategy(placement_policy='cuda')
|
||||
else:
|
||||
raise ValueError(f'Unsupported strategy "{strategy}"')
|
||||
|
@ -50,7 +53,7 @@ def run_test_data(strategy):
|
|||
reward_model = RewardModel(deepcopy(critic.model)).cuda()
|
||||
|
||||
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model)
|
||||
replay_buffer = NaiveReplayBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)
|
||||
data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)
|
||||
|
||||
# experience of all ranks should be the same
|
||||
for _ in range(2):
|
||||
|
@ -69,12 +72,12 @@ def run_test_data(strategy):
|
|||
assert gather_and_equal(experience.advantages)
|
||||
assert gather_and_equal(experience.action_mask)
|
||||
assert gather_and_equal(experience.attention_mask)
|
||||
replay_buffer.append(experience)
|
||||
data_buffer.append(experience)
|
||||
|
||||
# replay buffer's data should be the same
|
||||
buffer_size = torch.tensor([len(replay_buffer)], device='cuda')
|
||||
# data buffer's data should be the same
|
||||
buffer_size = torch.tensor([len(data_buffer)], device='cuda')
|
||||
assert gather_and_equal(buffer_size)
|
||||
for item in replay_buffer.items:
|
||||
for item in data_buffer.items:
|
||||
assert gather_and_equal(item.sequences)
|
||||
assert gather_and_equal(item.action_log_probs)
|
||||
assert gather_and_equal(item.values)
|
||||
|
@ -84,7 +87,7 @@ def run_test_data(strategy):
|
|||
assert gather_and_equal(item.attention_mask)
|
||||
|
||||
# dataloader of each rank should have the same size and different batch
|
||||
dataloader = strategy.setup_dataloader(replay_buffer)
|
||||
dataloader = strategy.setup_dataloader(data_buffer)
|
||||
dataloader_size = torch.tensor([len(dataloader)], device='cuda')
|
||||
assert gather_and_equal(dataloader_size)
|
||||
for experience in dataloader:
|
||||
|
@ -102,17 +105,16 @@ def run_dist(rank, world_size, port, strategy):
|
|||
os.environ['WORLD_SIZE'] = str(world_size)
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
os.environ['MASTER_PORT'] = str(port)
|
||||
run_test_data(strategy)
|
||||
make_and_consume_experience(strategy)
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [2])
|
||||
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai'])
|
||||
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai-zero2', 'colossalai-gemini'])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_data(world_size, strategy):
|
||||
def test_experience(world_size, strategy):
|
||||
spawn(run_dist, world_size, strategy=strategy)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_data(2, 'colossalai')
|
||||
test_experience(2, 'colossalai')
|
|
@ -0,0 +1,11 @@
|
|||
set -xue
|
||||
|
||||
BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
|
||||
EXAMPLES_DIR=$BASE_DIR/examples
|
||||
|
||||
echo "[Test]: testing inference ..."
|
||||
|
||||
# HACK: skip llama due to oom
|
||||
for model in 'gpt2' 'bloom' 'opt'; do
|
||||
python $EXAMPLES_DIR/inference.py --model $model
|
||||
done
|
|
@ -0,0 +1,235 @@
|
|||
import copy
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from coati.models.base import Actor, Critic, RewardModel, get_base_model
|
||||
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
|
||||
from coati.models.generation import generate
|
||||
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
|
||||
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
|
||||
from coati.models.lora import LoraLinear, convert_to_lora_module
|
||||
from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
|
||||
from coati.models.opt import OPTRM, OPTActor, OPTCritic
|
||||
from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seq_len", [32])
|
||||
@pytest.mark.parametrize("actor_maker", [
|
||||
lambda: BLOOMActor(),
|
||||
lambda: GPTActor(),
|
||||
# HACK: skip llama due to long execution time
|
||||
# lambda: LlamaActor(),
|
||||
lambda: OPTActor()
|
||||
])
|
||||
@pytest.mark.parametrize("generate_kwargs", [{
|
||||
"max_length": 64,
|
||||
"use_cache": True,
|
||||
"do_sample": True,
|
||||
"temperature": 1.0,
|
||||
"top_k": 50,
|
||||
}])
|
||||
def test_generation(actor_maker: Callable[[], Actor],
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
generate_kwargs: Dict[str, Any]
|
||||
):
|
||||
actor = actor_maker()
|
||||
input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
|
||||
sequences = generate(actor.cuda(), input_ids, **generate_kwargs)
|
||||
assert sequences.shape == (batch_size, generate_kwargs["max_length"])
|
||||
|
||||
|
||||
@pytest.mark.cpu
|
||||
def test_utils():
|
||||
fn_input = {
|
||||
"tensor": torch.ones((10, )),
|
||||
"mask": torch.randint(0, 2, (10, ))
|
||||
}
|
||||
fn_output = masked_mean(dim=0, **fn_input)
|
||||
assert fn_output.dim() == 0
|
||||
assert torch.allclose(fn_output, torch.tensor(1.0))
|
||||
|
||||
batch_size = 4
|
||||
num_labels = 10
|
||||
fn_input = {
|
||||
"r": torch.ones((batch_size, )),
|
||||
"kl_coef": 1.0,
|
||||
"log_probs": torch.randn((batch_size, num_labels)),
|
||||
"log_probs_base": torch.randn((batch_size, num_labels)),
|
||||
"action_mask": torch.randint(0, 2, (batch_size, num_labels))
|
||||
}
|
||||
fn_output = compute_reward(**fn_input)
|
||||
assert fn_output.shape == (batch_size, )
|
||||
|
||||
batch_size = 4
|
||||
seq_len = 32
|
||||
num_labels = 10
|
||||
num_actions = 2
|
||||
fn_input = {
|
||||
"output": {
|
||||
"logits": torch.randn((batch_size, seq_len, num_labels))
|
||||
},
|
||||
"sequences": torch.randint(0, num_labels, (batch_size, seq_len)),
|
||||
"num_actions": num_actions,
|
||||
}
|
||||
fn_output = calc_action_log_probs(**fn_input)
|
||||
assert fn_output.shape == (batch_size, num_actions)
|
||||
|
||||
|
||||
@pytest.mark.cpu
|
||||
@pytest.mark.parametrize("lora_rank", [4])
|
||||
@pytest.mark.parametrize("num_dim", [32])
|
||||
@pytest.mark.parametrize("num_layers", [4])
|
||||
def test_lora(lora_rank: int,
|
||||
num_dim: int,
|
||||
num_layers: int):
|
||||
model = nn.ModuleList(
|
||||
[nn.Linear(num_dim, num_dim)
|
||||
for _ in range(num_layers)]
|
||||
)
|
||||
lora_model = convert_to_lora_module(model, lora_rank)
|
||||
assert isinstance(lora_model, nn.ModuleList)
|
||||
for i in range(num_layers):
|
||||
assert isinstance(lora_model[i], LoraLinear)
|
||||
assert lora_model[i].lora_A.shape == (lora_rank, num_dim)
|
||||
assert lora_model[i].lora_B.shape == (num_dim, lora_rank)
|
||||
|
||||
old_model = copy.deepcopy(lora_model)
|
||||
for i in range(num_layers):
|
||||
assert isinstance(lora_model[i], LoraLinear)
|
||||
assert torch.allclose(old_model[i].weight, lora_model[i].weight)
|
||||
assert torch.allclose(old_model[i].bias, lora_model[i].bias)
|
||||
assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A,
|
||||
lora_model[i].lora_B @ lora_model[i].lora_A)
|
||||
optimizer = torch.optim.Adam(lora_model.parameters())
|
||||
x = torch.randn(8, num_dim)
|
||||
for i in range(num_layers):
|
||||
x = lora_model[i](x)
|
||||
loss = x.sum()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
for i in range(num_layers):
|
||||
assert isinstance(lora_model[i], LoraLinear)
|
||||
assert torch.allclose(old_model[i].weight, lora_model[i].weight)
|
||||
assert torch.allclose(old_model[i].bias, lora_model[i].bias)
|
||||
assert not torch.allclose(old_model[i].lora_B @ old_model[i].lora_A,
|
||||
lora_model[i].lora_B @ lora_model[i].lora_A)
|
||||
|
||||
|
||||
@pytest.mark.cpu
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [128])
|
||||
@pytest.mark.parametrize("models_maker", [
|
||||
lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
|
||||
lambda: (GPTActor(), GPTCritic(), GPTRM()),
|
||||
# HACK: skip llama due to long execution time
|
||||
# lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
|
||||
lambda: (OPTActor(), OPTCritic(), OPTRM()),
|
||||
])
|
||||
@torch.no_grad()
|
||||
def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]],
|
||||
batch_size: int,
|
||||
seq_len: int):
|
||||
|
||||
actor_input = {
|
||||
"input_ids": torch.randint(0, 100, (batch_size, seq_len)),
|
||||
"attention_mask": torch.randint(0, 2, (batch_size, seq_len))
|
||||
}
|
||||
critic_input = {
|
||||
"sequences": torch.randint(0, 100, (batch_size, seq_len)),
|
||||
"action_mask": torch.randint(0, 2, (batch_size, seq_len)),
|
||||
"attention_mask": torch.randint(0, 2, (batch_size, seq_len))
|
||||
}
|
||||
rm_input = {
|
||||
"sequences": torch.randint(0, 100, (batch_size, seq_len)),
|
||||
"attention_mask": torch.randint(0, 2, (batch_size, seq_len))
|
||||
}
|
||||
|
||||
actor, critic, rm = models_maker()
|
||||
assert isinstance(actor, Actor)
|
||||
base_actor_model = get_base_model(actor)
|
||||
assert isinstance(critic, Critic)
|
||||
base_critic_model = get_base_model(critic)
|
||||
assert isinstance(rm, RewardModel)
|
||||
base_rm_model = get_base_model(rm)
|
||||
|
||||
actor_output = actor(**actor_input)
|
||||
critic_output = critic(**critic_input)
|
||||
rm_output = rm(**rm_input)
|
||||
|
||||
assert actor_output.logits.shape[:2] == (batch_size, seq_len)
|
||||
assert critic_output.shape == (batch_size, )
|
||||
assert rm_output.shape == (batch_size, )
|
||||
|
||||
|
||||
@pytest.mark.cpu
|
||||
@pytest.mark.parametrize("batch_size", [16])
|
||||
@pytest.mark.parametrize("seq_len", [128])
|
||||
@pytest.mark.parametrize("num_labels", [100])
|
||||
def test_loss(batch_size: int,
|
||||
seq_len: int,
|
||||
num_labels: int):
|
||||
loss = GPTLMLoss()
|
||||
loss_input = {
|
||||
"logits": torch.randn(batch_size, seq_len, num_labels),
|
||||
"labels": torch.randint(0, num_labels, (batch_size, seq_len))
|
||||
}
|
||||
loss_output = loss(**loss_input)
|
||||
|
||||
loss = PolicyLoss()
|
||||
loss_input = {
|
||||
"log_probs": torch.randn(batch_size, ),
|
||||
"old_log_probs": torch.randn(batch_size, ),
|
||||
"advantages": torch.randn(batch_size, )
|
||||
}
|
||||
loss_output = loss(**loss_input)
|
||||
|
||||
loss = ValueLoss()
|
||||
loss_input = {
|
||||
"values": torch.randn(batch_size, ),
|
||||
"old_values": torch.randn(batch_size, ),
|
||||
"reward": torch.randn(batch_size, )
|
||||
}
|
||||
loss_output = loss(**loss_input)
|
||||
|
||||
loss = LogSigLoss()
|
||||
loss_input = {
|
||||
"chosen_reward": torch.randn(batch_size, ),
|
||||
"reject_reward": torch.randn(batch_size, ),
|
||||
}
|
||||
loss_output = loss(**loss_input)
|
||||
|
||||
loss = LogExpLoss()
|
||||
loss_input = {
|
||||
"chosen_reward": torch.randn(batch_size, ),
|
||||
"reject_reward": torch.randn(batch_size, ),
|
||||
}
|
||||
loss_output = loss(**loss_input)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_kwargs = dict(max_length=40,
|
||||
use_cache=True,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50)
|
||||
test_generation(lambda: LlamaActor(),
|
||||
batch_size=4,
|
||||
seq_len=32,
|
||||
generate_kwargs=generate_kwargs)
|
||||
|
||||
test_utils()
|
||||
|
||||
test_lora(lora_rank=2, num_dim=8, num_layers=2)
|
||||
|
||||
test_models(models_maker=lambda: (BLOOMActor(),
|
||||
BLOOMCritic(),
|
||||
BLOOMRM()),
|
||||
batch_size=8,
|
||||
seq_len=128)
|
||||
|
||||
test_loss(batch_size=8, seq_len=128, num_labels=100)
|
|
@ -0,0 +1,228 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 4
|
||||
|
||||
set -xu
|
||||
|
||||
if [ -z "$SFT_DATASET" ]; then
|
||||
echo "Please set \$SFT_DATASET to the path to sft dataset."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "$PROMPT_PATH" ]; then
|
||||
echo "Please set \$PROMPT_PATH to the path to prompts csv."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "$PRETRAIN_DATASET" ]; then
|
||||
echo "Please set \$PRETRAIN_DATASET to the path to alpaca data."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
NUM_RETRY=3
|
||||
BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
|
||||
EXAMPLES_DIR=$BASE_DIR/examples
|
||||
MODELS_DIR=$BASE_DIR/examples/models_config
|
||||
MODELS=('gpt2' 'bloom' 'opt' 'llama')
|
||||
STRATEGIES=('ddp' 'colossalai_gemini' 'colossalai_zero2')
|
||||
|
||||
export OMP_NUM_THREADS=8
|
||||
|
||||
# install requirements
|
||||
pip install -r $EXAMPLES_DIR/requirements.txt
|
||||
|
||||
python $EXAMPLES_DIR/download_model.py --model-dir $MODELS_DIR --config-only
|
||||
|
||||
get_pretrain() {
|
||||
local model=$1
|
||||
if [[ $model == "gpt2" ]]; then
|
||||
echo "gpt2"
|
||||
elif [[ $model == "bloom" ]]; then
|
||||
echo "bigscience/bloom-560m"
|
||||
elif [[ $model == "opt" ]]; then
|
||||
echo "facebook/opt-350m"
|
||||
else
|
||||
echo "Unknown model $model"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
random_choice() {
|
||||
local arr=("$@")
|
||||
local len=${#arr[@]}
|
||||
local idx=$((RANDOM % len))
|
||||
echo ${arr[$idx]}
|
||||
}
|
||||
|
||||
echo "[Test]: testing sft ..."
|
||||
|
||||
# FIXME: This is a hack to skip tests that are not working
|
||||
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
|
||||
# - llama-*: These tests can be passed locally, skipped for long execution time
|
||||
SKIPPED_TESTS=(
|
||||
"gpt2-ddp"
|
||||
"llama-ddp"
|
||||
"llama-colossalai_gemini"
|
||||
"llama-colossalai_zero2"
|
||||
)
|
||||
|
||||
GRAD_CKPTS=('' '--grad_checkpoint')
|
||||
for lora_rank in '0' '4'; do
|
||||
for model in ${MODELS[@]}; do
|
||||
strategies=($(shuf -e "${STRATEGIES[@]}"))
|
||||
for strategy in ${strategies[@]}; do
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
|
||||
echo "[Test]: Skipped $model-$strategy-$lora_rank"
|
||||
continue
|
||||
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then
|
||||
echo "[Test]: Skipped $model-$strategy"
|
||||
continue
|
||||
fi
|
||||
pretrain=$(get_pretrain $model)
|
||||
pretrain_model=""
|
||||
if [[ $lora_rank -gt 0 ]]; then
|
||||
pretrain_model="--pretrain $pretrain"
|
||||
fi
|
||||
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
|
||||
torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_sft.py \
|
||||
$pretrain_model --tokenizer $MODELS_DIR/$model \
|
||||
--model $model --strategy $strategy --lora_rank $lora_rank $grad_ckpt \
|
||||
--dataset $SFT_DATASET --max_datasets_size 8 \
|
||||
--max_epochs 1 --batch_size 1 --accumulation_steps 1 \
|
||||
--save_path $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank}
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed $model-$strategy-$lora_rank"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
echo "[Test]: testing reward model ..."
|
||||
|
||||
# FIXME: This is a hack to skip tests that are not working
|
||||
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
|
||||
# - llama-*: These tests can be passed locally, skipped for long execution time
|
||||
SKIPPED_TESTS=(
|
||||
"gpt2-ddp"
|
||||
"llama-ddp"
|
||||
"llama-colossalai_gemini"
|
||||
"llama-colossalai_zero2"
|
||||
)
|
||||
|
||||
LOSS_FNS=('log_sig' 'log_exp')
|
||||
DATASETS=('Anthropic/hh-rlhf' 'Dahoas/rm-static')
|
||||
for lora_rank in '0' '4'; do
|
||||
for model in ${MODELS[@]}; do
|
||||
strategies=($(shuf -e "${STRATEGIES[@]}"))
|
||||
for strategy in ${strategies[@]}; do
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
|
||||
echo "[Test]: Skipped $model-$strategy-$lora_rank"
|
||||
continue
|
||||
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then
|
||||
echo "[Test]: Skipped $model-$strategy"
|
||||
continue
|
||||
fi
|
||||
pretrain=$(get_pretrain $model)
|
||||
pretrain_model=""
|
||||
if [[ $lora_rank -gt 0 ]]; then
|
||||
pretrain_model="--pretrain $pretrain"
|
||||
fi
|
||||
loss_fn=$(random_choice "${LOSS_FNS[@]}")
|
||||
dataset=$(random_choice "${DATASETS[@]}")
|
||||
subset=$(if [[ $dataset == "Dahoas/rm-static" ]]; then echo "None"; else echo "harmless-base"; fi)
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
|
||||
torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_reward_model.py \
|
||||
$pretrain_model --tokenizer $MODELS_DIR/$model \
|
||||
--model $model --strategy $strategy --lora_rank $lora_rank --loss_fn $loss_fn \
|
||||
--dataset $dataset --subset $subset --test True --batch_size 1 \
|
||||
--save_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed to train reward model $model-$strategy-$lora_rank"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
echo "[Test]: testing RLHF ..."
|
||||
|
||||
# FIXME: This is a hack to skip tests that are not working
|
||||
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
|
||||
# - llama-*: These tests can be passed locally, skipped for long execution time
|
||||
SKIPPED_TESTS=(
|
||||
"gpt2-ddp"
|
||||
"llama-ddp"
|
||||
"llama-colossalai_gemini"
|
||||
"llama-colossalai_zero2"
|
||||
)
|
||||
|
||||
for model in ${MODELS[@]}; do
|
||||
for lora_rank in '0' '4'; do
|
||||
strategies=($(shuf -e "${STRATEGIES[@]}"))
|
||||
for strategy in ${strategies[@]}; do
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
|
||||
echo "[Test]: Skipped $model-$strategy-$lora_rank"
|
||||
continue
|
||||
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then
|
||||
echo "[Test]: Skipped $model-$strategy"
|
||||
continue
|
||||
fi
|
||||
rm_pretrain=$(get_pretrain $model)
|
||||
rm_pretrain_model=""
|
||||
if [[ $lora_rank -gt 0 ]]; then
|
||||
rm_pretrain_model="--rm_pretrain $rm_pretrain"
|
||||
fi
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
|
||||
torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_prompts.py \
|
||||
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
|
||||
--strategy $strategy --model $model --tokenizer $MODELS_DIR/$model \
|
||||
--num_episodes 1 --num_collect_steps 1 --num_update_steps 1 \
|
||||
--experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \
|
||||
--pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \
|
||||
$rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \
|
||||
--save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts.pt
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed to train RLHF $model-$strategy-$lora_rank"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
rm -rf $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank}
|
||||
rm $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
|
||||
done
|
||||
done
|
||||
rm $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts.pt
|
Loading…
Reference in New Issue