mirror of https://github.com/hpcaitech/ColossalAI
[chat]: update rm, add wandb and fix bugs (#4471)
* feat: modify forward fn of critic and reward model
* feat: modify calc_action_log_probs
* to: add wandb in sft and rm trainer
* feat: update train_sft
* feat: update train_rm
* style: modify type annotation and add warning
* feat: pass tokenizer to ppo trainer
* to: modify trainer base and maker base
* feat: add wandb in ppo trainer
* feat: pass tokenizer to generate
* test: update generate fn tests
* test: update train tests
* fix: remove action_mask
* feat: remove unused code
* fix: fix wrong ignore_index
* fix: fix mock tokenizer
* chore: update requirements
* revert: modify make_experience
* fix: fix inference
* fix: add padding side
* style: modify _on_learn_batch_end
* test: use mock tokenizer
* fix: use bf16 to avoid overflow
* fix: fix workflow
* [chat] fix gemini strategy
* [chat] fix
* sync: update colossalai strategy
* fix: fix args and model dtype
* fix: fix checkpoint test
* fix: fix requirements
* fix: fix missing import and wrong arg
* fix: temporarily skip gemini test in stage 3
* style: apply pre-commit
* fix: temporarily skip gemini test in stage 1&2

---------

Co-authored-by: Mingyan Jiang <1829166702@qq.com>
parent 07c2e3d09c
commit 7b9b86441f
@@ -49,5 +49,5 @@ jobs:
NCCL_SHM_DISABLE: 1
MAX_JOBS: 8
SFT_DATASET: /data/scratch/github_actions/chat/data.json
PROMPT_PATH: /data/scratch/github_actions/chat/prompts_en.jsonl
PROMPT_DATASET: /data/scratch/github_actions/chat/prompts_en.jsonl
PRETRAIN_DATASET: /data/scratch/github_actions/chat/alpaca_data.json

@@ -138,6 +138,7 @@ def main(args):
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))

@@ -154,6 +155,7 @@ def main(args):
initial_model,
actor_optim,
critic_optim,
tokenizer=tokenizer,
ptx_coef=0,
train_batch_size=args.train_batch_size,
offload_inference_models=args.offload_inference_models,

@@ -162,8 +164,6 @@ def main(args):
temperature=1.0,
top_k=50,
use_cache=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
callbacks=[performance_evaluator],
)

@@ -13,7 +13,7 @@
# limitations under the License.
import copy
from typing import Dict, Sequence, Tuple
from typing import Dict, Optional, Sequence, Tuple
import torch
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer

@@ -57,6 +57,7 @@ def _preprocess(
sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently"
labels = copy.deepcopy(sequences_token["input_ids"])
for i in range(labels.shape[0]):
source_len = sources_token["attention_mask"][i].sum().item()

@@ -64,9 +65,10 @@ def _preprocess(
if tokenizer.padding_side == "right":
# |prompt|completion|eos|pad|
labels[i][:source_len] = IGNORE_INDEX
labels[i][-pad_len:] = IGNORE_INDEX
elif tokenizer.padding_side == "left":
# |pad|prompt|completion|eos|
labels[i][pad_len : pad_len + source_len] = IGNORE_INDEX
labels[i][: pad_len + source_len] = IGNORE_INDEX
else:
raise RuntimeError()

@@ -126,6 +128,8 @@ class SFTDataset(Dataset):
sources = [data["prompt"] for data in dataset]
targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())]
logger.info("Tokenizing inputs... This may take some time...")
if isinstance(tokenizer, ChatGLMTokenizer):
self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
sources, targets, tokenizer, max_length

@@ -133,6 +137,8 @@ class SFTDataset(Dataset):
else:
self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
logger.info("Loaded dataset.")
def __len__(self):
length = self.input_ids.shape[0]
return length

@@ -148,7 +154,11 @@ class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(
self, data_path: str, tokenizer: PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512
self,
data_path: str,
tokenizer: PreTrainedTokenizer,
max_datasets_size: Optional[int] = None,
max_length: int = 512,
):
super().__init__()
logger.info("Loading data...")

@@ -175,6 +185,8 @@ class SupervisedDataset(Dataset):
else:
self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
logger.info("Loaded dataset.")
def __len__(self):
length = self.input_ids.shape[0]
return length

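For reference, the left-padding branch above masks both the pad tokens and the prompt in a single slice, so the loss is computed only on the completion. A minimal sketch with made-up token ids (not taken from the repository):

import torch

IGNORE_INDEX = -100  # matches the default ignore_index of nn.CrossEntropyLoss

# layout: |pad|pad|prompt (3 tokens)|completion (2 tokens)|eos|
input_ids = torch.tensor([[0, 0, 11, 12, 13, 21, 22, 2]])
pad_len, source_len = 2, 3

labels = input_ids.clone()
labels[0][: pad_len + source_len] = IGNORE_INDEX  # mask padding and prompt in one slice
# labels is now tensor([[-100, -100, -100, -100, -100, 21, 22, 2]])
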
@@ -1,4 +1,5 @@
import random
import warnings
from typing import List
import torch

@@ -30,9 +31,11 @@ class NaiveExperienceBuffer(ExperienceBuffer):
experience.to_device(torch.device("cpu"))
items = split_experience_batch(experience)
self.items.extend(items)
if self.limit > 0:
samples_to_remove = len(self.items) - self.limit
if samples_to_remove > 0:
warnings.warn(f"Experience buffer is full. Removing {samples_to_remove} samples.")
self.items = self.items[samples_to_remove:]
def clear(self) -> None:

@@ -3,8 +3,7 @@ from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from coati.models.base import Actor
from coati.models.base import Actor, Critic, RewardModel

@dataclass

@@ -59,16 +58,13 @@ class Experience:
class ExperienceMaker(ABC):
def __init__(
self, actor: Actor, critic: nn.Module, reward_model: nn.Module, initial_model: Actor, kl_coef: float = 0.1
) -> None:
def __init__(self, actor: Actor, critic: Critic, reward_model: RewardModel, initial_model: Actor) -> None:
super().__init__()
self.actor = actor
self.critic = critic
self.reward_model = reward_model
self.initial_model = initial_model
self.kl_coef = kl_coef
@abstractmethod
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
pass

@@ -1,7 +1,9 @@
import torch
import torch.nn.functional as F
from coati.models.base import Actor, Critic, RewardModel
from coati.models.generation import generate
from coati.models.utils import calc_action_log_probs, compute_reward
from transformers import PreTrainedTokenizer
from .base import Experience, ExperienceMaker

@@ -11,6 +13,19 @@ class NaiveExperienceMaker(ExperienceMaker):
Naive experience maker.
"""
def __init__(
self,
actor: Actor,
critic: Critic,
reward_model: RewardModel,
initial_model: Actor,
tokenizer: PreTrainedTokenizer,
kl_coef: float = 0.1,
) -> None:
super().__init__(actor, critic, reward_model, initial_model)
self.tokenizer = tokenizer
self.kl_coef = kl_coef
@torch.no_grad()
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
self.actor.eval()

@@ -19,16 +34,16 @@ class NaiveExperienceMaker(ExperienceMaker):
self.reward_model.eval()
# generate sequences
sequences = generate(self.actor, input_ids, **generate_kwargs)
sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs)
# calculate auxiliary tensors
attention_mask = None
pad_token_id = generate_kwargs.get("pad_token_id", None)
pad_token_id = self.tokenizer.pad_token_id
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
input_len = input_ids.size(1)
eos_token_id = generate_kwargs.get("eos_token_id", None)
eos_token_id = self.tokenizer.eos_token_id
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:

@@ -40,11 +55,11 @@ class NaiveExperienceMaker(ExperienceMaker):
action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
num_actions = action_mask.size(1)
actor_output = self.actor(sequences, attention_mask)
actor_output = self.actor(sequences, attention_mask)["logits"]
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
base_model_output = self.initial_model(sequences, attention_mask)
base_model_output = self.initial_model(sequences, attention_mask)["logits"]
base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
value = self.critic(sequences, action_mask, attention_mask)
value = self.critic(sequences, attention_mask)
r = self.reward_model(sequences, attention_mask)
reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)

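The last line combines the reward-model score with a KL penalty against the frozen initial model. A hedged sketch of one common formulation of that combination; the exact compute_reward implementation in coati may differ in details such as clipping:

import torch

def kl_penalized_reward(r, kl_coef, log_probs, base_log_probs, action_mask):
    # per-token log-ratio between the trained actor and the frozen reference policy
    log_ratio = (log_probs - base_log_probs) * action_mask
    kl = log_ratio.sum(dim=-1) / action_mask.sum(dim=-1)  # mean KL over generated tokens
    return r - kl_coef * kl  # shape (B,): reward, penalized for drifting from the reference
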
@@ -25,7 +25,7 @@ class Actor(LoRAModule):
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
**model_kwargs, # HACK: `generate` method may pass more kwargs
**model_kwargs,
) -> torch.Tensor:
"""Returns model output."""
output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)

@@ -1,10 +1,7 @@
from typing import Optional
import torch
import torch.nn as nn
from ..lora import LoRAModule
from ..utils import masked_mean

class Critic(LoRAModule):

@@ -19,37 +16,19 @@ class Critic(LoRAModule):
"""
def __init__(
self,
model: nn.Module,
value_head: nn.Module,
lora_rank: int = 0,
lora_train_bias: str = "none",
use_action_mask: bool = False,
self, model: nn.Module, value_head: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none"
) -> None:
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model
self.value_head = value_head
self.use_action_mask = use_action_mask
self.convert_to_lora()
def forward(
self,
sequences: torch.LongTensor,
action_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
outputs = self.model(sequences, attention_mask=attention_mask)
last_hidden_states = outputs["last_hidden_state"]
values = self.value_head(last_hidden_states).squeeze(-1)
if action_mask is not None and self.use_action_mask:
num_actions = action_mask.size(1)
prompt_mask = attention_mask[:, :-num_actions]
values = values[:, :-num_actions]
value = masked_mean(values, prompt_mask, dim=1)
return value
values = values[:, :-1]
value = values.mean(dim=1)
return value
sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[
0
]
sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths]
values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, )
return values

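The new forward pass scores each sequence from the hidden state of its last attended (non-pad) token, found by taking the maximum of mask-weighted positions. A tiny self-contained check of that indexing, with dummy tensors:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],    # left-padded sequence
                               [1, 1, 1, 0, 0]])   # right-padded sequence
positions = torch.arange(attention_mask.size(1))
sequence_lengths = torch.max(attention_mask * positions, dim=1)[0]
# sequence_lengths == tensor([4, 2]): index of the last token with mask == 1

hidden_states = torch.randn(2, 5, 8)               # (batch, seq_len, hidden)
last_hidden = hidden_states[torch.arange(2), sequence_lengths]  # (batch, hidden)

Because each position index is multiplied by the mask, padding positions contribute 0 and can never win the max, so the same indexing works for either padding side.
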
@@ -35,9 +35,12 @@ class RewardModel(LoRAModule):
else:
self.value_head = nn.Linear(model.config.n_embd, 1)
def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
outputs = self.model(sequences, attention_mask=attention_mask)
last_hidden_states = outputs["last_hidden_state"]
values = self.value_head(last_hidden_states)[:, :-1]
value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
return value
sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[
0
]
sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths]
values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, )
return values

@@ -2,6 +2,7 @@ from typing import Any, Callable, Optional
import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizer
from .base import Actor

@@ -63,8 +64,8 @@ def _sample(
)
outputs = model(**model_inputs)
# NOTE: this is correct only in left padding mode
next_token_logits = outputs["logits"][:, -1, :]
# pre-process distribution
next_token_logits = logits_processor(input_ids, next_token_logits)
# sample
probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)

@@ -72,8 +73,7 @@ def _sample(
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs for next step

@@ -96,12 +96,11 @@ def _sample(
def generate(
model: Actor,
input_ids: torch.Tensor,
tokenizer: PreTrainedTokenizer,
max_length: int,
num_beams: int = 1,
do_sample: bool = True,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,

@@ -118,14 +117,13 @@ def generate(
num_beams (int, optional): number of beams. Defaults to 1.
do_sample (bool, optional): whether to do sample. Defaults to True.
early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
pad_token_id (Optional[int], optional): pad token id. Defaults to None.
top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None.
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
"""
assert tokenizer.padding_side == "left", "Current generation only supports left padding."
is_greedy_gen_mode = (num_beams == 1) and do_sample is False
is_sample_gen_mode = (num_beams == 1) and do_sample is True
is_beam_gen_mode = (num_beams > 1) and do_sample is False

@@ -139,8 +137,8 @@ def generate(
input_ids,
max_length,
early_stopping=early_stopping,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
top_k=top_k,
top_p=top_p,
temperature=temperature,

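generate now asserts left padding because _sample reads the logits at the last position of every row, so that position must hold a real prompt token rather than padding. A short illustration of preparing such a batch with a Hugging Face tokenizer; the gpt2 checkpoint here is just an example:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token   # gpt2 has no pad token by default
tokenizer.padding_side = "left"             # pads go before the prompt

batch = tokenizer(
    ["a short prompt", "a somewhat longer prompt for the same batch"],
    padding=True,
    return_tensors="pt",
)
# batch["input_ids"][:, -1] is a real token for every row, so
# outputs["logits"][:, -1, :] is the next-token distribution for each prompt.
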
@@ -13,6 +13,7 @@ class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
# NOTE: default ignore_index is -100, which is equal to IGNORE_INDEX in sft_dataset.py
self.loss = nn.CrossEntropyLoss()
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:

@@ -46,18 +46,17 @@ def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.
return log_probs_labels.squeeze(-1)

def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
"""Calculate action log probs.
Args:
output (torch.Tensor): Output tensor of Actor.forward.
output (torch.Tensor): Output tensor of Actor.forward.logits.
sequences (torch.LongTensor): Input sequences.
num_actions (int): Number of actions.
Returns:
torch.Tensor: Action log probs.
"""
logits = output["logits"]
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]

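For context, the shift-and-slice contract of calc_action_log_probs can be reproduced with plain PyTorch as below. This is a hedged sketch with toy shapes that mirrors the function's behavior rather than copying its exact code:

import torch
import torch.nn.functional as F

def action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
    # the token at position t is predicted by the logits at position t - 1, hence the shift
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    token_log_probs = log_probs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
    return token_log_probs[:, -num_actions:]  # keep only the generated (action) tokens

logits = torch.randn(2, 10, 32)              # (batch, seq_len, vocab)
sequences = torch.randint(0, 32, (2, 10))
assert action_log_probs(logits, sequences, num_actions=4).shape == (2, 4)
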
@@ -41,13 +41,13 @@ def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_ra
def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == "gpt2":
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
elif model == "bloom":
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
elif model == "opt":
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
elif model == "llama":
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
else:
raise ValueError(f'Unsupported reward model "{model}"')
return critic

@@ -7,11 +7,10 @@ import tqdm
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from .callbacks import Callback
from .strategies import Strategy
from .utils import CycledDataLoader, is_rank_0
from .utils import is_rank_0

class SLTrainer(ABC):

@@ -47,11 +46,11 @@ class SLTrainer(ABC):
raise NotImplementedError()
def _before_fit(self):
self.no_epoch_bar = False
raise NotImplementedError()
def fit(self, *args, **kwargs):
self._before_fit(*args, **kwargs)
for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0() or self.no_epoch_bar):
for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0()):
self._train(epoch)
self._eval(epoch)

@@ -123,9 +122,9 @@ class OnPolicyTrainer(ABC):
for callback in self.callbacks:
callback.on_learn_batch_start()
def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
def _on_learn_batch_end(self, experience: Experience) -> None:
for callback in self.callbacks:
callback.on_learn_batch_end(metrics, experience)
callback.on_learn_batch_end(experience)
@abstractmethod
def _make_experience(self, collect_step: int):

@@ -153,27 +152,26 @@ class OnPolicyTrainer(ABC):
self._learn(update_step)
self._on_learn_epoch_end(update_step)
def _before_fit(self, *args, **kwargs):
raise NotImplementedError()
def fit(
self,
prompt_dataloader: DataLoader,
pretrain_dataloader: DataLoader,
num_episodes: int,
num_collect_steps: int,
num_update_steps: int,
*args,
**kwargs,
):
"""
The main training loop of on-policy rl trainers.
Args:
prompt_dataloader (DataLoader): the dataloader to use for prompt data
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
num_episodes (int): the number of episodes to train
num_collect_steps (int): the number of collect steps per episode
num_update_steps (int): the number of update steps per episode
"""
self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
self._before_fit(*args, **kwargs)
with self._fit_ctx():
for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
with self._episode_ctx(episode):

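CycledDataLoader, imported from the trainer's utils module, lets the loop pull a fixed number of collect/update steps per episode without caring about epoch boundaries. The real implementation lives in the repo; a minimal wrapper with the same next() interface might look like this:

from torch.utils.data import DataLoader

class CyclingLoader:
    """Illustrative stand-in for CycledDataLoader, not the coati implementation."""

    def __init__(self, dataloader: DataLoader) -> None:
        self.dataloader = dataloader
        self._iterator = iter(dataloader)

    def next(self):
        try:
            return next(self._iterator)
        except StopIteration:
            # restart transparently so callers can keep requesting batches
            self._iterator = iter(self.dataloader)
            return next(self._iterator)
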
@@ -35,5 +35,5 @@ class Callback(ABC):
def on_learn_batch_start(self) -> None:
pass
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
def on_learn_batch_end(self, experience: Experience) -> None:
pass

@@ -137,7 +137,7 @@ class PerformanceEvaluator(Callback):
return
self.learn_timer.start()
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
def on_learn_batch_end(self, experience: Experience) -> None:
if self.disable:
return
self.learn_timer.end()

@@ -1,27 +1,26 @@
from typing import Dict, List
from typing import Dict, List, Optional
import torch.nn as nn
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, get_base_model
from coati.models.base import Actor, Critic, RewardModel, get_base_model
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DistributedSampler
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase
from colossalai.utils import get_current_device
from .base import OnPolicyTrainer
from .callbacks import Callback
from .strategies import GeminiStrategy, Strategy
from .utils import is_rank_0, to_device
from .utils import CycledDataLoader, is_rank_0, to_device

def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
unwrapper_model = strategy.unwrap_model(actor)
hf_model = get_base_model(unwrapper_model)
unwrapped_model = strategy.unwrap_model(actor)
hf_model = get_base_model(unwrapped_model)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):

@@ -41,7 +40,7 @@ class PPOTrainer(OnPolicyTrainer):
strategy (Strategy): the strategy to use for training
actor (Actor): the actor model in ppo algorithm
critic (Critic): the critic model in ppo algorithm
reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences
reward_model (RewardModel): the reward model in rlhf algorithm to make reward of sentences
initial_model (Actor): the initial model in rlhf algorithm to generate reference logics to limit the update of actor
actor_optim (Optimizer): the optimizer to use for actor model
critic_optim (Optimizer): the optimizer to use for critic model

@@ -65,10 +64,11 @@ class PPOTrainer(OnPolicyTrainer):
strategy: Strategy,
actor: Actor,
critic: Critic,
reward_model: nn.Module,
reward_model: RewardModel,
initial_model: Actor,
actor_optim: Optimizer,
critic_optim: Optimizer,
tokenizer: PreTrainedTokenizerBase,
kl_coef: float = 0.1,
ptx_coef: float = 0.9,
train_batch_size: int = 8,

@@ -90,11 +90,11 @@ class PPOTrainer(OnPolicyTrainer):
super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)
self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
self.offload_inference_models = offload_inference_models
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer, kl_coef)
self.actor = actor
self.critic = critic
self.tokenizer = tokenizer
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)

@@ -104,58 +104,81 @@ class PPOTrainer(OnPolicyTrainer):
self.actor_optim = actor_optim
self.critic_optim = critic_optim
self.offload_inference_models = offload_inference_models
self.device = get_current_device()
def _before_fit(
self,
prompt_dataloader: DataLoader,
pretrain_dataloader: DataLoader,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
prompt_dataloader (DataLoader): the dataloader to use for prompt data
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
"""
self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
wandb.init(project="Coati-ppo", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "ppo")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
def _make_experience(self, collect_step: int) -> Experience:
prompts = self.prompt_dataloader.next()
if self.offload_inference_models:
# TODO(ver217): this may be controlled by strategy if they are prepared by strategy
self.experience_maker.initial_model.to(self.device)
self.experience_maker.reward_model.to(self.device)
if isinstance(prompts, Tensor):
return self.experience_maker.make_experience(prompts, **self.generate_kwargs)
elif isinstance(prompts, dict):
return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
else:
raise ValueError(f'Unsupported input type "{type(prompts)}"')
assert isinstance(prompts, dict), f'Unsupported input type "{type(prompts)}"'
return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
def _training_step(self, experience: Experience) -> Dict[str, float]:
def _training_step(self, experience: Experience):
self.actor.train()
self.critic.train()
# policy loss
num_actions = experience.action_mask.size(1)
actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
num_actions = experience.action_log_probs.size(1)
actor_logits = self.actor(experience.sequences, experience.attention_mask)["logits"]
action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
actor_loss = self.actor_loss_fn(
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
)
actor_loss = (1 - self.ptx_coef) * actor_loss
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
# ptx loss
if self.ptx_coef != 0:
batch = self.pretrain_dataloader.next()
batch = to_device(batch, self.device)
ptx_log_probs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"])["logits"]
ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch["labels"])
actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
ptx_log_probs = self.actor(batch["input_ids"], batch["attention_mask"])["logits"]
ptx_loss = self.ptx_coef * self.ptx_loss_fn(ptx_log_probs, batch["labels"])
self.strategy.backward(ptx_loss, self.actor, self.actor_optim)
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
self.strategy.optimizer_step(self.actor_optim)
self.actor_optim.zero_grad()
# value loss
values = self.critic(
experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
)
critic_loss = self.critic_loss_fn(
values, experience.values, experience.reward, action_mask=experience.action_mask
)
values = self.critic(experience.sequences, attention_mask=experience.attention_mask)
critic_loss = self.critic_loss_fn(values, experience.values, experience.reward)
critic_loss = critic_loss * self.vf_coef
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
return {"reward": experience.reward.mean().item()}
def _learn(self, update_step: int):
if self.offload_inference_models:
self.experience_maker.initial_model.to("cpu")

@@ -166,8 +189,8 @@ class PPOTrainer(OnPolicyTrainer):
experience = self.data_buffer.sample()
self._on_learn_batch_start()
experience.to_device(self.device)
metrics = self._training_step(experience)
self._on_learn_batch_end(metrics, experience)
self._training_step(experience)
self._on_learn_batch_end(experience)
else:
if isinstance(self.dataloader.sampler, DistributedSampler):
self.dataloader.sampler.set_epoch(update_step)

@@ -175,6 +198,5 @@ class PPOTrainer(OnPolicyTrainer):
for experience in pbar:
self._on_learn_batch_start()
experience.to_device(self.device)
metrics = self._training_step(experience)
self._on_learn_batch_end(metrics, experience)
pbar.set_postfix(metrics)
self._training_step(experience)
self._on_learn_batch_end(experience)

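One detail worth noting in _training_step: the old code summed the scaled PPO loss and the pretraining (ptx) loss before a single backward, while the new code calls backward twice with the same scaling factors. Assuming the strategy's backward simply accumulates gradients, as plain autograd does, both variants produce the same total gradient; a small self-contained check with toy losses, not the trainer's real ones:

import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
ptx_coef = 0.4
actor_loss = (w ** 2).sum()   # stand-in for the PPO policy loss
ptx_loss = (3 * w).sum()      # stand-in for the pretraining language-model loss

((1 - ptx_coef) * actor_loss).backward()
(ptx_coef * ptx_loss).backward()   # gradients add up in w.grad across backward calls
# w.grad == (1 - ptx_coef) * 2 * w + ptx_coef * 3, the gradient of the old combined loss
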
@@ -1,7 +1,5 @@
from datetime import datetime
from typing import Callable
from typing import Callable, Optional
import pandas as pd
import torch
import tqdm
from torch.optim import Optimizer

@@ -40,10 +38,12 @@ class RewardModelTrainer(SLTrainer):
self.loss_fn = loss_fn
self.scheduler = lr_scheduler
self.num_train_step = 0
def _eval(self, epoch):
if self.eval_dataloader is not None:
self.model.eval()
dist, on, cnt = 0, 0, 0
dist, num_correct, num_samples = 0, 0, 0
with torch.no_grad():
for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())

@@ -52,27 +52,21 @@ class RewardModelTrainer(SLTrainer):
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
for i in range(len(chosen_reward)):
cnt += 1
if chosen_reward[i] > reject_reward[i]:
on += 1
num_samples += chosen_ids.size(0)
num_correct += (chosen_reward > reject_reward).sum().item()
dist += (chosen_reward - reject_reward).mean().item()
self.dist = dist / len(self.eval_dataloader)
self.acc = on / cnt
self.acc = num_correct / num_samples
if is_rank_0():
log = pd.DataFrame(
[[(epoch + 1) * len(self.train_dataloader), self.loss.item(), self.dist, self.acc]],
columns=["step", "loss", "dist", "acc"],
)
log.to_csv("log.csv", mode="a", header=False, index=False)
if self.writer:
self.writer.add_scalar("eval/dist", self.dist, epoch)
self.writer.add_scalar("eval/acc", self.acc, epoch)
def _train(self, epoch):
self.model.train()
step_bar = tqdm.trange(
len(self.train_dataloader), desc="Train step of epoch %d" % epoch, disable=not is_rank_0()
len(self.train_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0()
)
cnt = 0
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())

@@ -80,26 +74,50 @@ class RewardModelTrainer(SLTrainer):
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
self.loss = self.loss_fn(chosen_reward, reject_reward)
self.strategy.backward(self.loss, self.model, self.optimizer)
loss = self.loss_fn(chosen_reward, reject_reward)
self.strategy.backward(loss, self.model, self.optimizer)
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
cnt += 1
if cnt % 100 == 0:
if self.writer:
self.writer.add_scalar("train/loss", loss.item(), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar("train/dist", (chosen_reward - reject_reward).mean().item(), self.num_train_step)
self.writer.add_scalar(
"train/acc", (chosen_reward > reject_reward).float().mean().item(), self.num_train_step
)
self.num_train_step += 1
if self.num_train_step % 100 == 0:
self.scheduler.step()
step_bar.update()
step_bar.close()
def _before_fit(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, eval_dataloader: DataLoader):
def _before_fit(
self,
train_dataloader: DataLoader,
eval_dataloader: DataLoader,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
train_dataloader (DataLoader): the dataloader to use for training
valid_dataloader (DataLoader): the dataloader to use for validation
eval_dataloader (DataLoader): the dataloader to use for evaluation
"""
super()._before_fit()
self.datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
self.train_dataloader = train_dataloader
self.valid_dataloader = valid_dataloader
self.eval_dataloader = eval_dataloader
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
wandb.init(project="Coati-rm", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "rm")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)

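The reworked _eval computes accuracy by counting how often the chosen response out-scores the rejected one. For reference, a hedged sketch of the pairwise objective behind the "log_sig" option in the training scripts and of the same accuracy metric; this is not the verbatim coati loss class:

import torch
import torch.nn.functional as F

def log_sig_loss(chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
    # -log sigmoid(r_chosen - r_reject): pushes chosen rewards above rejected ones
    return -F.logsigmoid(chosen_reward - reject_reward).mean()

chosen = torch.tensor([1.2, 0.3, 2.0])
reject = torch.tensor([0.8, 0.9, 1.0])
loss = log_sig_loss(chosen, reject)
acc = (chosen > reject).float().mean()   # same counting as num_correct / num_samples
# acc == 2/3 for these illustrative scores
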
@@ -1,10 +1,8 @@
import time
from typing import Optional
import torch
import torch.distributed as dist
import tqdm
import wandb
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

@@ -48,38 +46,34 @@ class SFTTrainer(SLTrainer):
self.accumulation_steps = accumulation_steps
self.scheduler = lr_scheduler
self.num_train_step = 0
self.num_eval_step = 0
def _train(self, epoch: int):
self.model.train()
for batch_id, batch in enumerate(self.train_dataloader):
step_bar = tqdm.trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
if "attention_mask" in batch:
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
else:
outputs = self.model(batch["input_ids"], labels=batch["labels"])
loss = outputs.loss
loss = loss / self.accumulation_steps
self.strategy.backward(loss, self.model, self.optimizer)
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
loss = outputs.loss / self.accumulation_steps
self.total_loss += loss.item()
self.strategy.backward(loss, self.model, self.optimizer)
# gradient accumulation
if (batch_id + 1) % self.accumulation_steps == 0:
if (i + 1) % self.accumulation_steps == 0:
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
self.scheduler.step()
if is_rank_0() and self.use_wandb:
wandb.log(
{
"loss": self.total_loss / self.accumulation_steps,
"lr": self.scheduler.get_last_lr()[0],
"epoch": epoch,
"batch_id": batch_id,
}
)
if self.writer:
self.writer.add_scalar("train/loss", self.total_loss, self.num_train_step)
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
self.num_train_step += 1
self.total_loss = 0
self.step_bar.update()
step_bar.update()
step_bar.close()
def _eval(self, epoch: int):
if self.eval_dataloader is not None:

@@ -91,20 +85,21 @@ class SFTTrainer(SLTrainer):
outputs = self.model(
batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]
)
loss = outputs.loss
loss_sum += loss.item()
loss_sum += outputs.loss.item()
num_seen += batch["input_ids"].size(0)
loss_mean = loss_sum / num_seen
if dist.get_rank() == 0:
self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}")
if self.writer:
self.writer.add_scalar("eval/loss", loss_mean, self.num_eval_step)
self.num_eval_step += 1
def _before_fit(
self,
train_dataloader: DataLoader,
eval_dataloader: Optional[DataLoader] = None,
logger: Optional[DistributedLogger] = None,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""

@@ -116,15 +111,20 @@ class SFTTrainer(SLTrainer):
self.eval_dataloader = eval_dataloader
self.logger = logger
self.use_wandb = use_wandb
if use_wandb:
wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
wandb.watch(self.model)
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
wandb.init(project="Coati-sft", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "sft")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
self.total_loss = 0
self.no_epoch_bar = True
self.step_bar = tqdm.trange(
len(self.train_dataloader) // self.accumulation_steps * self.max_epochs,
desc=f"steps",
disable=not is_rank_0(),
)

@@ -1,17 +1,13 @@
import warnings
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import colossalai
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.gemini_plugin import GeminiModel
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
from .ddp import DDPStrategy

@@ -65,14 +61,11 @@ class LowLevelZeroStrategy(DDPStrategy):
assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"'
plugin_initializer = lambda: LowLevelZeroPlugin(
# zero_config
stage=stage,
precision=precision,
# zero_optim_config
reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication,
cpu_offload=(placement_policy == "cpu"),
# optim_config
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,

@@ -136,7 +129,7 @@ class GeminiStrategy(DDPStrategy):
self,
seed: int = 42,
shard_init: bool = False, # only for stage 3
placement_policy: str = "cuda",
placement_policy: str = "auto",
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
search_range_m: int = 32, # only for stage 3

@@ -153,8 +146,6 @@ class GeminiStrategy(DDPStrategy):
max_norm: float = 0.0,
norm_type: float = 2.0,
) -> None:
assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
# TODO(ver217): support shard_init when using from_pretrained()
if shard_init:
warnings.warn(

@@ -167,8 +158,7 @@ class GeminiStrategy(DDPStrategy):
# NOTE: dist should be initialized before calling get_current_device()
plugin_initializer = lambda: GeminiPlugin(
# gemini_config
device=get_current_device(),
chunk_init_device=get_current_device(),
placement_policy=placement_policy,
precision="fp16",
pin_memory=pin_memory,

@@ -177,9 +167,7 @@ class GeminiStrategy(DDPStrategy):
search_range_m=search_range_m,
hidden_dim=hidden_dim,
min_chunk_size_m=min_chunk_size_m,
# zero_optim_config
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
# optim_config
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,

@@ -200,15 +188,8 @@ class GeminiStrategy(DDPStrategy):
colossalai.launch_from_torch({}, seed=self.seed)
def model_init_context(self):
world_size = dist.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
return ColoInitContext(
device=get_current_device(), dtype=torch.half, default_pg=shard_pg, default_dist_spec=default_dist_spec
)
return LazyInitContext(default_device=get_current_device())
def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, GeminiModel)
ddp_model = model.unwrap()
assert isinstance(ddp_model, GeminiDDP)
return ddp_model.module
assert isinstance(model, GeminiDDP)
return model.module

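With this change, model_init_context returns colossalai's LazyInitContext instead of ColoInitContext, so parameters are materialized lazily on the target device. Caller-side usage stays the same pattern as in the example scripts; a short hedged sketch (model class, arguments, and import paths mirror those used elsewhere in this diff but are illustrative):

from coati.models.gpt import GPTActor
from coati.trainer.strategies import GeminiStrategy

# assumes the usual distributed launcher (e.g. torchrun), as in the example scripts
strategy = GeminiStrategy(placement_policy="auto")
with strategy.model_init_context():
    # weights are created lazily instead of being fully materialized up front
    actor = GPTActor(pretrained="gpt2", lora_rank=0)
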
@@ -45,9 +45,17 @@ def eval(args):
raise ValueError(f'Unsupported model "{args.model}"')
actor.eval()
tokenizer.padding_side = "left"
input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device())
outputs = generate(
actor, input_ids, max_length=args.max_length, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1
actor,
input_ids,
tokenizer=tokenizer,
max_length=args.max_length,
do_sample=True,
top_k=50,
top_p=0.95,
num_return_sequences=1,
)
output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
print(f"[Output]: {''.join(output)}")

@@ -1,3 +1,3 @@
pandas>=1.4.1
sentencepiece
colossalai==0.3.1
colossalai>=0.3.1

@@ -23,7 +23,7 @@ def main(args):
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5)
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:

@@ -65,8 +65,8 @@ def main(args):
if args.rm_path is not None:
reward_model.load_state_dict(state_dict, strict=False)
initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device())
initial_model.to(torch.bfloat16).to(torch.cuda.current_device())
reward_model.to(torch.bfloat16).to(torch.cuda.current_device())
if args.model == "gpt2":
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)

@@ -80,13 +80,13 @@ def main(args):
raise ValueError(f'Unsupported actor model "{args.model}"')
if rm_model_name == "gpt2":
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == "bloom":
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == "opt":
critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == "llama":
critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')

@@ -94,17 +94,16 @@ def main(args):
critic.load_state_dict(state_dict, strict=False)
del state_dict
if args.strategy != "colossalai_gemini":
critic.to(torch.float16).to(torch.cuda.current_device())
actor.to(torch.float16).to(torch.cuda.current_device())
actor.to(torch.bfloat16).to(torch.cuda.current_device())
critic.to(torch.bfloat16).to(torch.cuda.current_device())
# configure optimizer
if args.strategy.startswith("colossalai"):
actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
actor_optim = HybridAdam(actor.parameters(), lr=args.lr)
critic_optim = HybridAdam(critic.parameters(), lr=args.lr)
else:
actor_optim = Adam(actor.parameters(), lr=1e-7)
critic_optim = Adam(critic.parameters(), lr=1e-7)
actor_optim = Adam(actor.parameters(), lr=args.lr)
critic_optim = Adam(critic.parameters(), lr=args.lr)
# configure tokenizer
if args.model == "gpt2":

@@ -126,8 +125,15 @@ def main(args):
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
# NOTE: generate() requires padding_side to be "left"
tokenizer.padding_side = "left"
prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_dataset, max_datasets_size=16384)
prompt_dataset = PromptDataset(
tokenizer=tokenizer,
data_path=args.prompt_dataset,
max_datasets_size=args.max_datasets_size,
max_length=args.max_input_len,
)
if dist.is_initialized() and dist.get_world_size() > 1:
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else:

@@ -137,7 +143,10 @@ def main(args):
)
pretrain_dataset = SupervisedDataset(
tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384, max_length=args.max_input_len
tokenizer=tokenizer,
data_path=args.pretrain_dataset,
max_datasets_size=args.max_datasets_size,
max_length=args.max_input_len,
)
if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)

@@ -161,6 +170,7 @@ def main(args):
initial_model,
actor_optim,
critic_optim,
tokenizer=tokenizer,
kl_coef=args.kl_coef,
ptx_coef=args.ptx_coef,
train_batch_size=args.train_batch_size,

@@ -169,17 +179,17 @@ def main(args):
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
offload_inference_models=args.strategy != "colossalai_gemini",
)
trainer.fit(
prompt_dataloader=prompt_dataloader,
pretrain_dataloader=pretrain_dataloader,
num_episodes=args.num_episodes,
num_collect_steps=args.num_collect_steps,
num_update_steps=args.num_update_steps,
prompt_dataloader=prompt_dataloader,
pretrain_dataloader=pretrain_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
# save model checkpoint after fitting

@@ -195,6 +205,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset")
parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
parser.add_argument("--max_datasets_size", type=int, default=50000)
parser.add_argument(
"--strategy",
choices=["ddp", "colossalai_gemini", "colossalai_zero2"],

@@ -216,9 +227,12 @@ if __name__ == "__main__":
parser.add_argument("--ptx_batch_size", type=int, default=1)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--lr", type=float, default=1e-7)
parser.add_argument("--kl_coef", type=float, default=0.1)
parser.add_argument("--ptx_coef", type=float, default=0.9)
parser.add_argument("--max_input_len", type=int, default=96)
parser.add_argument("--max_seq_len", type=int, default=128)
parser.add_argument("--log_dir", default="logs", type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
args = parser.parse_args()
main(args)

@@ -1,5 +1,4 @@
import argparse
from random import randint

import torch
import torch.distributed as dist
@@ -27,7 +26,7 @@ def train(args):
    if args.strategy == "ddp":
        strategy = DDPStrategy()
    elif args.strategy == "colossalai_gemini":
        strategy = GeminiStrategy(placement_policy="cuda")
        strategy = GeminiStrategy(placement_policy="auto")
    elif args.strategy == "colossalai_zero2":
        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    else:
@@ -46,7 +45,7 @@ def train(args):
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    model.to(torch.float16).to(torch.cuda.current_device())
    model.to(torch.bfloat16).to(torch.cuda.current_device())

    if args.model_path is not None:
        state_dict = torch.load(args.model_path)
@@ -75,9 +74,9 @@ def train(args):

    # configure optimizer
    if args.strategy.startswith("colossalai"):
        optim = HybridAdam(model.parameters(), lr=5e-6)
        optim = HybridAdam(model.parameters(), lr=args.lr)
    else:
        optim = Adam(model.parameters(), lr=5e-6)
        optim = Adam(model.parameters(), lr=args.lr)

    # configure loss function
    if args.loss_fn == "log_sig":
@@ -93,21 +92,14 @@ def train(args):
    else:
        data = load_dataset(args.dataset)

    if args.test:
        train_data = data["train"].select(range(20))
        eval_data = data["test"].select(range(5))
    else:
        train_data = data["train"]
        eval_data = data["test"]
    valid_data = data["test"].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5)))
    train_data = data["train"].select(range(min(args.max_datasets_size, len(data["train"]))))
    eval_data = data["test"].select(range(min(args.max_datasets_size, len(data["test"]))))

    if args.dataset == "Dahoas/rm-static":
        train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len)
        valid_dataset = RmStaticDataset(valid_data, tokenizer, args.max_len)
        eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len)
    elif args.dataset == "Anthropic/hh-rlhf":
        train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len)
        valid_dataset = HhRlhfDataset(valid_data, tokenizer, args.max_len)
        eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len)
    else:
        raise ValueError(f'Unsupported dataset "{args.dataset}"')
@@ -121,14 +113,6 @@ def train(args):
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
        )
        valid_sampler = DistributedSampler(
            valid_dataset,
            shuffle=True,
            seed=42,
            drop_last=True,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
        )
        eval_sampler = DistributedSampler(
            eval_dataset,
            shuffle=True,
@@ -139,7 +123,6 @@ def train(args):
        )
    else:
        train_sampler = None
        valid_sampler = None
        eval_sampler = None

    train_dataloader = DataLoader(
@@ -150,14 +133,6 @@ def train(args):
        pin_memory=True,
    )

    valid_dataloader = DataLoader(
        valid_dataset,
        shuffle=(valid_sampler is None),
        sampler=valid_sampler,
        batch_size=args.batch_size,
        pin_memory=True,
    )

    eval_dataloader = DataLoader(
        eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True
    )
@@ -176,7 +151,12 @@ def train(args):
        max_epochs=args.max_epochs,
    )

    trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader)
    trainer.fit(
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        log_dir=args.log_dir,
        use_wandb=args.use_wandb,
    )
    # save model checkpoint after fitting on only rank0
    strategy.save_model(model, args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
@@ -200,12 +180,15 @@ if __name__ == "__main__":
        "--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static"
    )
    parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None)
    parser.add_argument("--max_datasets_size", type=int, default=1000000)
    parser.add_argument("--save_path", type=str, default="rm_ckpt")
    parser.add_argument("--max_epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--max_len", type=int, default=512)
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument("--lr", type=float, default=9e-6)
    parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
    parser.add_argument("--test", type=bool, default=False)
    parser.add_argument("--log_dir", default="logs", type=str)
    parser.add_argument("--use_wandb", default=False, action="store_true")
    args = parser.parse_args()
    train(args)

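The reward-model script above now casts weights to bfloat16 (the commit notes fp16 overflowed), reads the learning rate from the CLI instead of a hard-coded 5e-6, and forwards log_dir/use_wandb into the trainer. The snippet below is a minimal, self-contained sketch of that wiring with a toy torch module standing in for the real reward model; it is illustrative only and does not call the project's RewardModelTrainer.

# Sketch only: toy stand-in for the bf16 + CLI-lr setup in the script above.
import argparse

import torch
from torch import nn
from torch.optim import Adam

parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=9e-6)  # same default as above
parser.add_argument("--use_wandb", default=False, action="store_true")
args = parser.parse_args([])  # empty argv so the snippet runs as-is

model = nn.Linear(8, 1).to(torch.bfloat16)  # bf16 instead of fp16 to avoid overflow
if torch.cuda.is_available():
    model.to(torch.cuda.current_device())

optim = Adam(model.parameters(), lr=args.lr)  # lr comes from the CLI now
# with --use_wandb the real trainer additionally logs metrics to Weights & Biases
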
@@ -16,7 +16,10 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
set_n_least_used_CUDA_VISIBLE_DEVICES 2

torchrun --standalone --nproc_per_node=2 train_reward_model.py \
    --model 'bloom' \
    --pretrain 'gpt2' \
    --model 'gpt2' \
    --strategy colossalai_zero2 \
    --loss_fn 'log_sig' \
    --dataset 'Anthropic/hh-rlhf'
    --loss_fn 'log_exp' \
    --dataset 'Anthropic/hh-rlhf' \
    --batch_size 16 \
    --max_epochs 10

@@ -23,7 +23,6 @@ from transformers.trainer import get_scheduler

from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter


def train(args):
@@ -31,7 +30,7 @@ def train(args):
    if args.strategy == "ddp":
        strategy = DDPStrategy()
    elif args.strategy == "colossalai_gemini":
        strategy = GeminiStrategy(placement_policy="cuda")
        strategy = GeminiStrategy(placement_policy="auto")
    elif args.strategy == "colossalai_zero2":
        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    elif args.strategy == "colossalai_zero2_cpu":
@@ -57,7 +56,7 @@ def train(args):
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    model.to(torch.float16).to(torch.cuda.current_device())
    model.to(torch.bfloat16).to(torch.cuda.current_device())

    # configure tokenizer
    if args.model == "gpt2":
@@ -84,28 +83,21 @@ def train(args):
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    if args.model == "llama" and args.strategy == "colossalai_gemini":
        # this is a hack to deal with the resized embedding
        # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
        for name, param in model.named_parameters():
            if not isinstance(param, ColoParameter):
                sub_module_name = ".".join(name.split(".")[:-1])
                weight_name = name.split(".")[-1]
                sub_module = model.get_submodule(sub_module_name)
                setattr(sub_module, weight_name, ColoParameter(param))

    # configure optimizer
    if args.strategy.startswith("colossalai"):
        optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
    else:
        optim = Adam(model.parameters(), lr=args.lr)
    logger = get_dist_logger()

    # configure dataset
    if args.dataset == "yizhongw/self_instruct":
        train_data = load_dataset(args.dataset, "super_natural_instructions", split="train")
        eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test")

        if args.max_datasets_size is not None:
            train_data = train_data.select(range(min(args.max_datasets_size, len(train_data))))
            eval_data = eval_data.select(range(min(args.max_datasets_size, len(eval_data))))

        train_dataset = SFTDataset(train_data, tokenizer, args.max_len)
        eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len)

@@ -176,8 +168,13 @@ def train(args):
        accumulation_steps=args.accumulation_steps,
    )

    logger = get_dist_logger()
    trainer.fit(
        train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, logger=logger, use_wandb=args.use_wandb
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        logger=logger,
        log_dir=args.log_dir,
        use_wandb=args.use_wandb,
    )

    # save model checkpoint after fitting on only rank0
@@ -207,9 +204,9 @@ if __name__ == "__main__":
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--max_len", type=int, default=512)
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
    parser.add_argument("--lr", type=float, default=5e-6)
    parser.add_argument("--accumulation_steps", type=int, default=8)
    parser.add_argument("--log_dir", default="logs", type=str)
    parser.add_argument("--use_wandb", default=False, action="store_true")
    parser.add_argument("--grad_checkpoint", default=False, action="store_true")
    args = parser.parse_args()

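The SFT script above gained a --max_datasets_size cap applied with datasets.Dataset.select. Below is a small, self-contained sketch of the same pattern on an in-memory dataset (the real script pulls yizhongw/self_instruct from the Hub); the toy column names are made up.

# Sketch: cap a Hugging Face dataset the same way the script above does.
from datasets import Dataset

max_datasets_size = 2
train_data = Dataset.from_dict({"instruction": ["a", "b", "c", "d"], "output": ["w", "x", "y", "z"]})

if max_datasets_size is not None:
    train_data = train_data.select(range(min(max_datasets_size, len(train_data))))

assert len(train_data) == 2
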
@@ -19,7 +19,6 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
    --pretrain "/path/to/LLaMa-7B/" \
    --model 'llama' \
    --strategy colossalai_zero2 \
    --log_interval 10 \
    --save_path /path/to/Coati-7B \
    --dataset /path/to/data.json \
    --batch_size 4 \

@@ -1,2 +1,2 @@
pytest
colossalai==0.3.1
colossalai>=0.3.1

@@ -2,7 +2,7 @@ transformers>=4.20.1
tqdm
datasets
loralib
colossalai==0.3.1
colossalai>=0.3.1
torch<2.0.0, >=1.12.1
langchain
tokenizers
@@ -11,3 +11,4 @@ sse_starlette
wandb
sentencepiece
gpustat
tensorboard

@@ -25,8 +25,8 @@ def get_data(batch_size: int, seq_len: int = 10) -> dict:
def train_step(strategy: Strategy, actor: GPTActor, actor_optim: HybridAdam, batch_size: int = 8):
    data = get_data(batch_size)
    action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool)
    actor_output = actor(data["input_ids"], data["attention_mask"])
    action_log_probs = calc_action_log_probs(actor_output, data["input_ids"], action_mask.size(1))
    actor_logits = actor(data["input_ids"], data["attention_mask"])["logits"]
    action_log_probs = calc_action_log_probs(actor_logits, data["input_ids"], action_mask.size(1))
    loss = action_log_probs.sum()
    strategy.backward(loss, actor, actor_optim)
    strategy.optimizer_step(actor_optim)
@@ -36,7 +36,7 @@ def run_test_checkpoint(strategy_name: str, shard: bool):
    if strategy_name == "ddp":
        strategy = DDPStrategy()
    elif strategy_name == "colossalai_gemini":
        strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
        strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5)
    elif strategy_name == "colossalai_zero2":
        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    else:

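As the test above shows, calc_action_log_probs now takes the raw logits tensor (actor(...)["logits"]) rather than the whole model output. The snippet below hand-rolls an equivalent computation in plain PyTorch, assuming the helper returns the log-probability of each of the last num_actions generated tokens; it is an approximation for orientation, not coati's implementation.

# Plain-torch approximation of the calc_action_log_probs call in the test above.
import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size, num_actions = 2, 10, 50, 4
logits = torch.randn(batch_size, seq_len, vocab_size)      # e.g. actor(input_ids, attention_mask)["logits"]
sequences = torch.randint(0, vocab_size, (batch_size, seq_len))

# the token at position t is predicted by the logits at position t - 1
log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
per_token = log_probs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
action_log_probs = per_token[:, -num_actions:]              # shape (batch_size, num_actions)
assert action_log_probs.shape == (batch_size, num_actions)
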
@@ -226,7 +226,9 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size:
    check_content(input_ids.masked_select(attention_mask), tokenizer, model)
    assert torch.all(attention_mask)
    ignore_mask = labels == IGNORE_INDEX
    check_content(input_ids.masked_select(ignore_mask), tokenizer, model)
    prompt_mask = torch.logical_and(ignore_mask, attention_mask)
    check_content(input_ids.masked_select(prompt_mask), tokenizer, model)
    assert torch.all(input_ids.masked_select(ignore_mask ^ prompt_mask) == tokenizer.pad_token_id)


if __name__ == "__main__":

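The new assertions above rely on some simple mask algebra: with left padding, the IGNORE_INDEX positions in labels cover both the pad tokens and the prompt, AND-ing with the attention mask keeps only the prompt, and XOR-ing the two masks recovers the pads. A toy, runnable illustration with made-up token ids:

import torch

IGNORE_INDEX = -100
pad_id = 0
# layout: [pad, pad, prompt, prompt, completion, completion]
input_ids = torch.tensor([pad_id, pad_id, 11, 12, 13, 14])
attention_mask = torch.tensor([0, 0, 1, 1, 1, 1], dtype=torch.bool)
labels = torch.tensor([IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 13, 14])

ignore_mask = labels == IGNORE_INDEX                          # pads + prompt
prompt_mask = torch.logical_and(ignore_mask, attention_mask)  # prompt only

assert input_ids.masked_select(prompt_mask).tolist() == [11, 12]
assert torch.all(input_ids.masked_select(ignore_mask ^ prompt_mask) == pad_id)
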
@@ -1,5 +1,5 @@
import copy
import os
from copy import deepcopy

import pytest
import torch
@@ -8,6 +8,7 @@ from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import NaiveExperienceMaker
from coati.models.base import RewardModel
from coati.models.gpt import GPTActor, GPTCritic
from coati.trainer.ppo import _set_default_generate_kwargs
from coati.trainer.strategies import DDPStrategy, GeminiStrategy
from coati.trainer.strategies.colossalai import LowLevelZeroStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
@@ -42,27 +43,38 @@ def make_and_consume_experience(strategy):
    elif strategy == "colossalai-zero2":
        strategy = LowLevelZeroStrategy()
    elif strategy == "colossalai-gemini":
        strategy = GeminiStrategy(placement_policy="cuda")
        strategy = GeminiStrategy(placement_policy="static")
    else:
        raise ValueError(f'Unsupported strategy "{strategy}"')

    actor = GPTActor(config=GPT_CONFIG).cuda()
    critic = GPTCritic(config=GPT_CONFIG).cuda()
    with strategy.model_init_context():
        actor = GPTActor(config=GPT_CONFIG).cuda()
        critic = GPTCritic(config=GPT_CONFIG).cuda()

    initial_model = deepcopy(actor)
    reward_model = RewardModel(deepcopy(critic.model)).cuda()
    initial_model = GPTActor(config=GPT_CONFIG).cuda()
    reward_model = RewardModel(model=copy.deepcopy(critic.model)).cuda()

    experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model)
    actor, critic, initial_model, reward_model = strategy.prepare(actor, critic, initial_model, reward_model)

    class MockTokenizer:
        def __init__(self):
            self.padding_side = "left"
            self.eos_token_id = 0
            self.pad_token_id = 0

    tokenizer = MockTokenizer()
    experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer)
    data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)

    generate_kwargs = dict(do_sample=True, max_length=16)
    generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)

    # experience of all ranks should be the same
    for _ in range(2):
        data = get_data(EXPERIENCE_BATCH_SIZE)
        assert gather_and_equal(data["input_ids"])
        assert gather_and_equal(data["attention_mask"])
        experience = experience_maker.make_experience(
            **data, do_sample=True, max_length=16, eos_token_id=50256, pad_token_id=50256
        )
        experience = experience_maker.make_experience(**data, do_sample=True, max_length=16)
        assert gather_and_equal(experience.sequences)
        assert gather_and_equal(experience.action_log_probs)
        assert gather_and_equal(experience.values)
@@ -115,4 +127,4 @@ def test_experience(world_size, strategy):


if __name__ == "__main__":
    test_experience(2, "colossalai")
    test_experience(2, "colossalai-zero2")

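In the updated test above, pad/eos ids are no longer passed into make_experience; a tokenizer-like object is handed to NaiveExperienceMaker instead. The mock below mirrors the one in the diff; any object exposing these three attributes should satisfy the mock's role, though whether the experience maker reads anything beyond them is not shown here.

# Duck-typed stand-in for a real tokenizer, as used by the test above.
class MockTokenizer:
    def __init__(self):
        self.padding_side = "left"  # generation-style (left) padding
        self.eos_token_id = 0
        self.pad_token_id = 0

tokenizer = MockTokenizer()
assert tokenizer.pad_token_id == tokenizer.eos_token_id
# experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer)
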
@@ -14,7 +14,7 @@ from coati.models.llama import LlamaActor
from coati.models.lora import LoraLinear, convert_to_lora_module
from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
from coati.models.utils import calc_action_log_probs, masked_mean


@pytest.mark.parametrize("batch_size", [4])
@@ -27,7 +26,7 @@ from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
        # HACK: skip llama due to long execution time
        # lambda: LlamaActor(),
        lambda: OPTActor(),
        # lambda: ChatGLMActor(),
    ],
)
@pytest.mark.parametrize(
@@ -43,9 +42,16 @@ from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
    ],
)
def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
    class MockTokenizer:
        def __init__(self):
            self.padding_side = "left"
            self.eos_token_id = 0
            self.pad_token_id = 0

    actor = actor_maker()
    input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
    sequences = generate(actor.cuda(), input_ids, **generate_kwargs)
    tokenizer = MockTokenizer()
    sequences = generate(actor.cuda(), input_ids, tokenizer, **generate_kwargs)
    assert sequences.shape == (batch_size, generate_kwargs["max_length"])


@@ -55,24 +61,12 @@ def test_utils():
    assert fn_output.dim() == 0
    assert torch.allclose(fn_output, torch.tensor(1.0))

    batch_size = 4
    num_labels = 10
    fn_input = {
        "r": torch.ones((batch_size,)),
        "kl_coef": 1.0,
        "log_probs": torch.randn((batch_size, num_labels)),
        "log_probs_base": torch.randn((batch_size, num_labels)),
        "action_mask": torch.randint(0, 2, (batch_size, num_labels)),
    }
    fn_output = compute_reward(**fn_input)
    assert fn_output.shape == (batch_size,)

    batch_size = 4
    seq_len = 32
    num_labels = 10
    num_actions = 2
    fn_input = {
        "output": {"logits": torch.randn((batch_size, seq_len, num_labels))},
        "logits": torch.randn((batch_size, seq_len, num_labels)),
        "sequences": torch.randint(0, num_labels, (batch_size, seq_len)),
        "num_actions": num_actions,
    }
@@ -135,7 +129,6 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], b
    }
    critic_input = {
        "sequences": torch.randint(0, 100, (batch_size, seq_len)),
        "action_mask": torch.randint(0, 2, (batch_size, seq_len)),
        "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
    }
    rm_input = {

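The generation test above now passes the tokenizer into coati's generate() rather than separate pad/eos ids. For orientation only, the snippet below shows the same idea with plain Hugging Face transformers, where the ids are likewise read off a tokenizer object; it is an analogy, not coati's generate(), and it downloads the small gpt2 checkpoint.

# Analogy in vanilla transformers: token ids come from the tokenizer object.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

model = GPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tokenizer("Hello", return_tensors="pt").input_ids
sequences = model.generate(
    input_ids,
    do_sample=True,
    max_length=16,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
assert sequences.shape[1] <= 16
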
@@ -24,8 +24,8 @@ if [ -z "$SFT_DATASET" ]; then
    exit 1
fi

if [ -z "$PROMPT_PATH" ]; then
    echo "Please set \$PROMPT_PATH to the path to prompts csv."
if [ -z "$PROMPT_DATASET" ]; then
    echo "Please set \$PROMPT_DATASET to the path to prompts csv."
    exit 1
fi

@@ -74,11 +74,15 @@ echo "[Test]: testing sft ..."
# FIXME: This is a hack to skip tests that are not working
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# - llama-*: These tests can be passed locally, skipped for long execution time
# - *-gemini: Gemini plugin does not support `from_pretrained` yet
SKIPPED_TESTS=(
    "gpt2-ddp"
    "llama-ddp"
    "llama-colossalai_gemini"
    "llama-colossalai_zero2"
    "gpt2-colossalai_gemini"
    "opt-colossalai_gemini"
    "bloom-colossalai_gemini"
)

GRAD_CKPTS=('' '--grad_checkpoint')
@@ -105,7 +109,7 @@ for lora_rank in '0' '4'; do
            $pretrain_model --tokenizer $MODELS_DIR/$model \
            --model $model --strategy $strategy --lora_rank $lora_rank $grad_ckpt \
            --dataset $SFT_DATASET --max_datasets_size 8 \
            --max_epochs 1 --batch_size 1 --accumulation_steps 1 \
            --max_epochs 1 --batch_size 1 --accumulation_steps 1 --lr 1e-8 \
            --save_path $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank}
        passed=$?
        if [ $passed -eq 0 ]; then
@@ -125,11 +129,15 @@ echo "[Test]: testing reward model ..."
# FIXME: This is a hack to skip tests that are not working
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# - llama-*: These tests can be passed locally, skipped for long execution time
# - *-gemini: Gemini plugin does not support `from_pretrained` yet
SKIPPED_TESTS=(
    "gpt2-ddp"
    "llama-ddp"
    "llama-colossalai_gemini"
    "llama-colossalai_zero2"
    "gpt2-colossalai_gemini"
    "opt-colossalai_gemini"
    "bloom-colossalai_gemini"
)

LOSS_FNS=('log_sig' 'log_exp')
@@ -157,8 +165,9 @@ for lora_rank in '0' '4'; do
        echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
        torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_reward_model.py \
            $pretrain_model --tokenizer $MODELS_DIR/$model \
            --model $model --strategy $strategy --lora_rank $lora_rank --loss_fn $loss_fn \
            --dataset $dataset --subset $subset --test True --batch_size 1 \
            --dataset $dataset --subset $subset --max_datasets_size 8 \
            --model $model --strategy $strategy --lora_rank $lora_rank \
            --loss_fn $loss_fn --batch_size 1 --lr 1e-8 \
            --save_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
        passed=$?
        if [ $passed -eq 0 ]; then
@@ -178,11 +187,15 @@ echo "[Test]: testing RLHF ..."
# FIXME: This is a hack to skip tests that are not working
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# - llama-*: These tests can be passed locally, skipped for long execution time
# - *-gemini: Gemini plugin does not support `from_pretrained` yet
SKIPPED_TESTS=(
    "gpt2-ddp"
    "llama-ddp"
    "llama-colossalai_gemini"
    "llama-colossalai_zero2"
    "gpt2-colossalai_gemini"
    "opt-colossalai_gemini"
    "bloom-colossalai_gemini"
)

for model in ${MODELS[@]}; do
@@ -204,9 +217,9 @@ for model in ${MODELS[@]}; do
    for i in $(seq $NUM_RETRY); do
        echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
        torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_prompts.py \
            --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
            --prompt_dataset $PROMPT_DATASET --pretrain_dataset $PRETRAIN_DATASET --max_datasets_size 32 \
            --strategy $strategy --model $model --tokenizer $MODELS_DIR/$model \
            --num_episodes 1 --num_collect_steps 1 --num_update_steps 1 \
            --num_episodes 1 --num_collect_steps 1 --num_update_steps 1 --lr 1e-8 \
            --experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \
            --pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \
            $rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \

@@ -3,6 +3,7 @@ import time

import pytest
import torch
from model_zoo import GPTLMLoss, get_gpt2_components
from torch.utils._pytree import tree_map

import colossalai
@@ -13,7 +14,6 @@ from colossalai.fx.profiler import parameter_size
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import spawn
from colossalai.utils import get_current_device
from model_zoo import GPTLMLoss, get_gpt2_components


def parse_args():

@@ -3,6 +3,7 @@ import time
from functools import partial

import torch
from model_zoo import model_builder
from torch import nn

from colossalai.fx import ColoTracer
@@ -12,7 +13,6 @@ from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology
from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine
from colossalai.legacy.pipeline.rpc.utils import rpc_run
from colossalai.logging import disable_existing_loggers, get_dist_logger
from model_zoo import model_builder


def parse_args():