[chat]: update rm, add wandb and fix bugs (#4471)

* feat: modify forward fn of critic and reward model

* feat: modify calc_action_log_probs

* to: add wandb in sft and rm trainer

* feat: update train_sft

* feat: update train_rm

* style: modify type annotation and add warning

* feat: pass tokenizer to ppo trainer

* to: modify trainer base and maker base

* feat: add wandb in ppo trainer

* feat: pass tokenizer to generate

* test: update generate fn tests

* test: update train tests

* fix: remove action_mask

* feat: remove unused code

* fix: fix wrong ignore_index

* fix: fix mock tokenizer

* chore: update requirements

* revert: modify make_experience

* fix: fix inference

* fix: add padding side

* style: modify _on_learn_batch_end

* test: use mock tokenizer

* fix: use bf16 to avoid overflow

* fix: fix workflow

* [chat] fix gemini strategy

* [chat] fix

* sync: update colossalai strategy

* fix: fix args and model dtype

* fix: fix checkpoint test

* fix: fix requirements

* fix: fix missing import and wrong arg

* fix: temporarily skip gemini test in stage 3

* style: apply pre-commit

* fix: temporarily skip gemini test in stage 1&2

---------

Co-authored-by: Mingyan Jiang <1829166702@qq.com>
pull/4766/head
Wenhao Chen 2023-09-20 15:53:58 +08:00 committed by GitHub
parent 07c2e3d09c
commit 7b9b86441f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 382 additions and 332 deletions

View File

@ -49,5 +49,5 @@ jobs:
NCCL_SHM_DISABLE: 1 NCCL_SHM_DISABLE: 1
MAX_JOBS: 8 MAX_JOBS: 8
SFT_DATASET: /data/scratch/github_actions/chat/data.json SFT_DATASET: /data/scratch/github_actions/chat/data.json
PROMPT_PATH: /data/scratch/github_actions/chat/prompts_en.jsonl PROMPT_DATASET: /data/scratch/github_actions/chat/prompts_en.jsonl
PRETRAIN_DATASET: /data/scratch/github_actions/chat/alpaca_data.json PRETRAIN_DATASET: /data/scratch/github_actions/chat/alpaca_data.json

View File

@ -138,6 +138,7 @@ def main(args):
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
@ -154,6 +155,7 @@ def main(args):
initial_model, initial_model,
actor_optim, actor_optim,
critic_optim, critic_optim,
tokenizer=tokenizer,
ptx_coef=0, ptx_coef=0,
train_batch_size=args.train_batch_size, train_batch_size=args.train_batch_size,
offload_inference_models=args.offload_inference_models, offload_inference_models=args.offload_inference_models,
@ -162,8 +164,6 @@ def main(args):
temperature=1.0, temperature=1.0,
top_k=50, top_k=50,
use_cache=True, use_cache=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
callbacks=[performance_evaluator], callbacks=[performance_evaluator],
) )

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import copy import copy
from typing import Dict, Sequence, Tuple from typing import Dict, Optional, Sequence, Tuple
import torch import torch
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
@ -57,6 +57,7 @@ def _preprocess(
sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
) )
assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently"
labels = copy.deepcopy(sequences_token["input_ids"]) labels = copy.deepcopy(sequences_token["input_ids"])
for i in range(labels.shape[0]): for i in range(labels.shape[0]):
source_len = sources_token["attention_mask"][i].sum().item() source_len = sources_token["attention_mask"][i].sum().item()
@ -64,9 +65,10 @@ def _preprocess(
if tokenizer.padding_side == "right": if tokenizer.padding_side == "right":
# |prompt|completion|eos|pad| # |prompt|completion|eos|pad|
labels[i][:source_len] = IGNORE_INDEX labels[i][:source_len] = IGNORE_INDEX
labels[i][-pad_len:] = IGNORE_INDEX
elif tokenizer.padding_side == "left": elif tokenizer.padding_side == "left":
# |pad|prompt|completion|eos| # |pad|prompt|completion|eos|
labels[i][pad_len : pad_len + source_len] = IGNORE_INDEX labels[i][: pad_len + source_len] = IGNORE_INDEX
else: else:
raise RuntimeError() raise RuntimeError()
@ -126,6 +128,8 @@ class SFTDataset(Dataset):
sources = [data["prompt"] for data in dataset] sources = [data["prompt"] for data in dataset]
targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())] targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())]
logger.info("Tokenizing inputs... This may take some time...")
if isinstance(tokenizer, ChatGLMTokenizer): if isinstance(tokenizer, ChatGLMTokenizer):
self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm( self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
sources, targets, tokenizer, max_length sources, targets, tokenizer, max_length
@ -133,6 +137,8 @@ class SFTDataset(Dataset):
else: else:
self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length) self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
logger.info("Loaded dataset.")
def __len__(self): def __len__(self):
length = self.input_ids.shape[0] length = self.input_ids.shape[0]
return length return length
@ -148,7 +154,11 @@ class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning.""" """Dataset for supervised fine-tuning."""
def __init__( def __init__(
self, data_path: str, tokenizer: PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512 self,
data_path: str,
tokenizer: PreTrainedTokenizer,
max_datasets_size: Optional[int] = None,
max_length: int = 512,
): ):
super().__init__() super().__init__()
logger.info("Loading data...") logger.info("Loading data...")
@ -175,6 +185,8 @@ class SupervisedDataset(Dataset):
else: else:
self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length) self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
logger.info("Loaded dataset.")
def __len__(self): def __len__(self):
length = self.input_ids.shape[0] length = self.input_ids.shape[0]
return length return length

View File

@ -1,4 +1,5 @@
import random import random
import warnings
from typing import List from typing import List
import torch import torch
@ -30,9 +31,11 @@ class NaiveExperienceBuffer(ExperienceBuffer):
experience.to_device(torch.device("cpu")) experience.to_device(torch.device("cpu"))
items = split_experience_batch(experience) items = split_experience_batch(experience)
self.items.extend(items) self.items.extend(items)
if self.limit > 0: if self.limit > 0:
samples_to_remove = len(self.items) - self.limit samples_to_remove = len(self.items) - self.limit
if samples_to_remove > 0: if samples_to_remove > 0:
warnings.warn(f"Experience buffer is full. Removing {samples_to_remove} samples.")
self.items = self.items[samples_to_remove:] self.items = self.items[samples_to_remove:]
def clear(self) -> None: def clear(self) -> None:

View File

@ -3,8 +3,7 @@ from dataclasses import dataclass
from typing import Optional from typing import Optional
import torch import torch
import torch.nn as nn from coati.models.base import Actor, Critic, RewardModel
from coati.models.base import Actor
@dataclass @dataclass
@ -59,16 +58,13 @@ class Experience:
class ExperienceMaker(ABC): class ExperienceMaker(ABC):
def __init__( def __init__(self, actor: Actor, critic: Critic, reward_model: RewardModel, initial_model: Actor) -> None:
self, actor: Actor, critic: nn.Module, reward_model: nn.Module, initial_model: Actor, kl_coef: float = 0.1
) -> None:
super().__init__() super().__init__()
self.actor = actor self.actor = actor
self.critic = critic self.critic = critic
self.reward_model = reward_model self.reward_model = reward_model
self.initial_model = initial_model self.initial_model = initial_model
self.kl_coef = kl_coef
@abstractmethod @abstractmethod
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience: def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
pass pass

View File

@ -1,7 +1,9 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from coati.models.base import Actor, Critic, RewardModel
from coati.models.generation import generate from coati.models.generation import generate
from coati.models.utils import calc_action_log_probs, compute_reward from coati.models.utils import calc_action_log_probs, compute_reward
from transformers import PreTrainedTokenizer
from .base import Experience, ExperienceMaker from .base import Experience, ExperienceMaker
@ -11,6 +13,19 @@ class NaiveExperienceMaker(ExperienceMaker):
Naive experience maker. Naive experience maker.
""" """
def __init__(
self,
actor: Actor,
critic: Critic,
reward_model: RewardModel,
initial_model: Actor,
tokenizer: PreTrainedTokenizer,
kl_coef: float = 0.1,
) -> None:
super().__init__(actor, critic, reward_model, initial_model)
self.tokenizer = tokenizer
self.kl_coef = kl_coef
@torch.no_grad() @torch.no_grad()
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience: def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
self.actor.eval() self.actor.eval()
@ -19,16 +34,16 @@ class NaiveExperienceMaker(ExperienceMaker):
self.reward_model.eval() self.reward_model.eval()
# generate sequences # generate sequences
sequences = generate(self.actor, input_ids, **generate_kwargs) sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs)
# calculate auxiliary tensors # calculate auxiliary tensors
attention_mask = None attention_mask = None
pad_token_id = generate_kwargs.get("pad_token_id", None) pad_token_id = self.tokenizer.pad_token_id
if pad_token_id is not None: if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
input_len = input_ids.size(1) input_len = input_ids.size(1)
eos_token_id = generate_kwargs.get("eos_token_id", None) eos_token_id = self.tokenizer.eos_token_id
if eos_token_id is None: if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool) action_mask = torch.ones_like(sequences, dtype=torch.bool)
else: else:
@ -40,11 +55,11 @@ class NaiveExperienceMaker(ExperienceMaker):
action_mask = action_mask[:, -(sequences.size(1) - input_len) :] action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
num_actions = action_mask.size(1) num_actions = action_mask.size(1)
actor_output = self.actor(sequences, attention_mask) actor_output = self.actor(sequences, attention_mask)["logits"]
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions) action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
base_model_output = self.initial_model(sequences, attention_mask) base_model_output = self.initial_model(sequences, attention_mask)["logits"]
base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions) base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
value = self.critic(sequences, action_mask, attention_mask) value = self.critic(sequences, attention_mask)
r = self.reward_model(sequences, attention_mask) r = self.reward_model(sequences, attention_mask)
reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask) reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)

View File

@ -25,7 +25,7 @@ class Actor(LoRAModule):
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
**model_kwargs, # HACK: `generate` method may pass more kwargs **model_kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
"""Returns model output.""" """Returns model output."""
output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs) output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)

View File

@ -1,10 +1,7 @@
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from ..lora import LoRAModule from ..lora import LoRAModule
from ..utils import masked_mean
class Critic(LoRAModule): class Critic(LoRAModule):
@ -19,37 +16,19 @@ class Critic(LoRAModule):
""" """
def __init__( def __init__(
self, self, model: nn.Module, value_head: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none"
model: nn.Module,
value_head: nn.Module,
lora_rank: int = 0,
lora_train_bias: str = "none",
use_action_mask: bool = False,
) -> None: ) -> None:
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model self.model = model
self.value_head = value_head self.value_head = value_head
self.use_action_mask = use_action_mask
self.convert_to_lora() self.convert_to_lora()
def forward( def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
self,
sequences: torch.LongTensor,
action_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
outputs = self.model(sequences, attention_mask=attention_mask) outputs = self.model(sequences, attention_mask=attention_mask)
last_hidden_states = outputs["last_hidden_state"] last_hidden_states = outputs["last_hidden_state"]
sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[
values = self.value_head(last_hidden_states).squeeze(-1) 0
]
if action_mask is not None and self.use_action_mask: sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths]
num_actions = action_mask.size(1) values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, )
prompt_mask = attention_mask[:, :-num_actions] return values
values = values[:, :-num_actions]
value = masked_mean(values, prompt_mask, dim=1)
return value
values = values[:, :-1]
value = values.mean(dim=1)
return value

View File

@ -35,9 +35,12 @@ class RewardModel(LoRAModule):
else: else:
self.value_head = nn.Linear(model.config.n_embd, 1) self.value_head = nn.Linear(model.config.n_embd, 1)
def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
outputs = self.model(sequences, attention_mask=attention_mask) outputs = self.model(sequences, attention_mask=attention_mask)
last_hidden_states = outputs["last_hidden_state"] last_hidden_states = outputs["last_hidden_state"]
values = self.value_head(last_hidden_states)[:, :-1] sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[
value = values.mean(dim=1).squeeze(1) # ensure shape is (B) 0
return value ]
sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths]
values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, )
return values

View File

@ -2,6 +2,7 @@ from typing import Any, Callable, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from transformers import PreTrainedTokenizer
from .base import Actor from .base import Actor
@ -63,8 +64,8 @@ def _sample(
) )
outputs = model(**model_inputs) outputs = model(**model_inputs)
# NOTE: this is correct only in left padding mode
next_token_logits = outputs["logits"][:, -1, :] next_token_logits = outputs["logits"][:, -1, :]
# pre-process distribution
next_token_logits = logits_processor(input_ids, next_token_logits) next_token_logits = logits_processor(input_ids, next_token_logits)
# sample # sample
probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float) probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
@ -72,8 +73,7 @@ def _sample(
# finished sentences should have their next token be a padding token # finished sentences should have their next token be a padding token
if eos_token_id is not None: if eos_token_id is not None:
if pad_token_id is None: assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs for next step # update generated ids, model inputs for next step
@ -96,12 +96,11 @@ def _sample(
def generate( def generate(
model: Actor, model: Actor,
input_ids: torch.Tensor, input_ids: torch.Tensor,
tokenizer: PreTrainedTokenizer,
max_length: int, max_length: int,
num_beams: int = 1, num_beams: int = 1,
do_sample: bool = True, do_sample: bool = True,
early_stopping: bool = False, early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None, top_k: Optional[int] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
@ -118,14 +117,13 @@ def generate(
num_beams (int, optional): number of beams. Defaults to 1. num_beams (int, optional): number of beams. Defaults to 1.
do_sample (bool, optional): whether to do sample. Defaults to True. do_sample (bool, optional): whether to do sample. Defaults to True.
early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False. early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
pad_token_id (Optional[int], optional): pad token id. Defaults to None.
top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None. top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None. top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None. temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None.
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None. prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None. update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
""" """
assert tokenizer.padding_side == "left", "Current generation only supports left padding."
is_greedy_gen_mode = (num_beams == 1) and do_sample is False is_greedy_gen_mode = (num_beams == 1) and do_sample is False
is_sample_gen_mode = (num_beams == 1) and do_sample is True is_sample_gen_mode = (num_beams == 1) and do_sample is True
is_beam_gen_mode = (num_beams > 1) and do_sample is False is_beam_gen_mode = (num_beams > 1) and do_sample is False
@ -139,8 +137,8 @@ def generate(
input_ids, input_ids,
max_length, max_length,
early_stopping=early_stopping, early_stopping=early_stopping,
eos_token_id=eos_token_id, eos_token_id=tokenizer.eos_token_id,
pad_token_id=pad_token_id, pad_token_id=tokenizer.pad_token_id,
top_k=top_k, top_k=top_k,
top_p=top_p, top_p=top_p,
temperature=temperature, temperature=temperature,

View File

@ -13,6 +13,7 @@ class GPTLMLoss(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
# NOTE: default ignore_index is -100, which is equal to IGNORE_INDEX in sft_dataset.py
self.loss = nn.CrossEntropyLoss() self.loss = nn.CrossEntropyLoss()
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:

View File

@ -46,18 +46,17 @@ def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.
return log_probs_labels.squeeze(-1) return log_probs_labels.squeeze(-1)
def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor: def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
"""Calculate action log probs. """Calculate action log probs.
Args: Args:
output (torch.Tensor): Output tensor of Actor.forward. output (torch.Tensor): Output tensor of Actor.forward.logits.
sequences (torch.LongTensor): Input sequences. sequences (torch.LongTensor): Input sequences.
num_actions (int): Number of actions. num_actions (int): Number of actions.
Returns: Returns:
torch.Tensor: Action log probs. torch.Tensor: Action log probs.
""" """
logits = output["logits"]
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:] return log_probs[:, -num_actions:]

View File

@ -41,13 +41,13 @@ def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_ra
def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0): def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == "gpt2": if model == "gpt2":
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
elif model == "bloom": elif model == "bloom":
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
elif model == "opt": elif model == "opt":
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
elif model == "llama": elif model == "llama":
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
else: else:
raise ValueError(f'Unsupported reward model "{model}"') raise ValueError(f'Unsupported reward model "{model}"')
return critic return critic

View File

@ -7,11 +7,10 @@ import tqdm
from coati.experience_buffer import NaiveExperienceBuffer from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience from coati.experience_maker import Experience
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import DataLoader
from .callbacks import Callback from .callbacks import Callback
from .strategies import Strategy from .strategies import Strategy
from .utils import CycledDataLoader, is_rank_0 from .utils import is_rank_0
class SLTrainer(ABC): class SLTrainer(ABC):
@ -47,11 +46,11 @@ class SLTrainer(ABC):
raise NotImplementedError() raise NotImplementedError()
def _before_fit(self): def _before_fit(self):
self.no_epoch_bar = False raise NotImplementedError()
def fit(self, *args, **kwargs): def fit(self, *args, **kwargs):
self._before_fit(*args, **kwargs) self._before_fit(*args, **kwargs)
for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0() or self.no_epoch_bar): for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0()):
self._train(epoch) self._train(epoch)
self._eval(epoch) self._eval(epoch)
@ -123,9 +122,9 @@ class OnPolicyTrainer(ABC):
for callback in self.callbacks: for callback in self.callbacks:
callback.on_learn_batch_start() callback.on_learn_batch_start()
def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None: def _on_learn_batch_end(self, experience: Experience) -> None:
for callback in self.callbacks: for callback in self.callbacks:
callback.on_learn_batch_end(metrics, experience) callback.on_learn_batch_end(experience)
@abstractmethod @abstractmethod
def _make_experience(self, collect_step: int): def _make_experience(self, collect_step: int):
@ -153,27 +152,26 @@ class OnPolicyTrainer(ABC):
self._learn(update_step) self._learn(update_step)
self._on_learn_epoch_end(update_step) self._on_learn_epoch_end(update_step)
def _before_fit(self, *args, **kwargs):
raise NotImplementedError()
def fit( def fit(
self, self,
prompt_dataloader: DataLoader,
pretrain_dataloader: DataLoader,
num_episodes: int, num_episodes: int,
num_collect_steps: int, num_collect_steps: int,
num_update_steps: int, num_update_steps: int,
*args,
**kwargs,
): ):
""" """
The main training loop of on-policy rl trainers. The main training loop of on-policy rl trainers.
Args: Args:
prompt_dataloader (DataLoader): the dataloader to use for prompt data
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
num_episodes (int): the number of episodes to train num_episodes (int): the number of episodes to train
num_collect_steps (int): the number of collect steps per episode num_collect_steps (int): the number of collect steps per episode
num_update_steps (int): the number of update steps per episode num_update_steps (int): the number of update steps per episode
""" """
self.prompt_dataloader = CycledDataLoader(prompt_dataloader) self._before_fit(*args, **kwargs)
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
with self._fit_ctx(): with self._fit_ctx():
for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()): for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
with self._episode_ctx(episode): with self._episode_ctx(episode):

View File

@ -35,5 +35,5 @@ class Callback(ABC):
def on_learn_batch_start(self) -> None: def on_learn_batch_start(self) -> None:
pass pass
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None: def on_learn_batch_end(self, experience: Experience) -> None:
pass pass

View File

@ -137,7 +137,7 @@ class PerformanceEvaluator(Callback):
return return
self.learn_timer.start() self.learn_timer.start()
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None: def on_learn_batch_end(self, experience: Experience) -> None:
if self.disable: if self.disable:
return return
self.learn_timer.end() self.learn_timer.end()

View File

@ -1,27 +1,26 @@
from typing import Dict, List from typing import Dict, List, Optional
import torch.nn as nn
from coati.experience_buffer import NaiveExperienceBuffer from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience, NaiveExperienceMaker from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, get_base_model from coati.models.base import Actor, Critic, RewardModel, get_base_model
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs from coati.models.utils import calc_action_log_probs
from torch import Tensor
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import DistributedSampler from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm from tqdm import tqdm
from transformers import PreTrainedTokenizerBase
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from .base import OnPolicyTrainer from .base import OnPolicyTrainer
from .callbacks import Callback from .callbacks import Callback
from .strategies import GeminiStrategy, Strategy from .strategies import GeminiStrategy, Strategy
from .utils import is_rank_0, to_device from .utils import CycledDataLoader, is_rank_0, to_device
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict: def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
unwrapper_model = strategy.unwrap_model(actor) unwrapped_model = strategy.unwrap_model(actor)
hf_model = get_base_model(unwrapper_model) hf_model = get_base_model(unwrapped_model)
new_kwargs = {**generate_kwargs} new_kwargs = {**generate_kwargs}
# use huggingface models method directly # use huggingface models method directly
if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"): if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):
@ -41,7 +40,7 @@ class PPOTrainer(OnPolicyTrainer):
strategy (Strategy): the strategy to use for training strategy (Strategy): the strategy to use for training
actor (Actor): the actor model in ppo algorithm actor (Actor): the actor model in ppo algorithm
critic (Critic): the critic model in ppo algorithm critic (Critic): the critic model in ppo algorithm
reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences reward_model (RewardModel): the reward model in rlhf algorithm to make reward of sentences
initial_model (Actor): the initial model in rlhf algorithm to generate reference logics to limit the update of actor initial_model (Actor): the initial model in rlhf algorithm to generate reference logics to limit the update of actor
actor_optim (Optimizer): the optimizer to use for actor model actor_optim (Optimizer): the optimizer to use for actor model
critic_optim (Optimizer): the optimizer to use for critic model critic_optim (Optimizer): the optimizer to use for critic model
@ -65,10 +64,11 @@ class PPOTrainer(OnPolicyTrainer):
strategy: Strategy, strategy: Strategy,
actor: Actor, actor: Actor,
critic: Critic, critic: Critic,
reward_model: nn.Module, reward_model: RewardModel,
initial_model: Actor, initial_model: Actor,
actor_optim: Optimizer, actor_optim: Optimizer,
critic_optim: Optimizer, critic_optim: Optimizer,
tokenizer: PreTrainedTokenizerBase,
kl_coef: float = 0.1, kl_coef: float = 0.1,
ptx_coef: float = 0.9, ptx_coef: float = 0.9,
train_batch_size: int = 8, train_batch_size: int = 8,
@ -90,11 +90,11 @@ class PPOTrainer(OnPolicyTrainer):
super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks) super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)
self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor) self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef) self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer, kl_coef)
self.offload_inference_models = offload_inference_models
self.actor = actor self.actor = actor
self.critic = critic self.critic = critic
self.tokenizer = tokenizer
self.actor_loss_fn = PolicyLoss(eps_clip) self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip) self.critic_loss_fn = ValueLoss(value_clip)
@ -104,58 +104,81 @@ class PPOTrainer(OnPolicyTrainer):
self.actor_optim = actor_optim self.actor_optim = actor_optim
self.critic_optim = critic_optim self.critic_optim = critic_optim
self.offload_inference_models = offload_inference_models
self.device = get_current_device() self.device = get_current_device()
def _before_fit(
self,
prompt_dataloader: DataLoader,
pretrain_dataloader: DataLoader,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
prompt_dataloader (DataLoader): the dataloader to use for prompt data
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
"""
self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
wandb.init(project="Coati-ppo", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "ppo")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
def _make_experience(self, collect_step: int) -> Experience: def _make_experience(self, collect_step: int) -> Experience:
prompts = self.prompt_dataloader.next() prompts = self.prompt_dataloader.next()
if self.offload_inference_models: if self.offload_inference_models:
# TODO(ver217): this may be controlled by strategy if they are prepared by strategy # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
self.experience_maker.initial_model.to(self.device) self.experience_maker.initial_model.to(self.device)
self.experience_maker.reward_model.to(self.device) self.experience_maker.reward_model.to(self.device)
if isinstance(prompts, Tensor): assert isinstance(prompts, dict), f'Unsupported input type "{type(prompts)}"'
return self.experience_maker.make_experience(prompts, **self.generate_kwargs) return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
elif isinstance(prompts, dict):
return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
else:
raise ValueError(f'Unsupported input type "{type(prompts)}"')
def _training_step(self, experience: Experience) -> Dict[str, float]: def _training_step(self, experience: Experience):
self.actor.train() self.actor.train()
self.critic.train() self.critic.train()
# policy loss # policy loss
num_actions = experience.action_mask.size(1) num_actions = experience.action_log_probs.size(1)
actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask) actor_logits = self.actor(experience.sequences, experience.attention_mask)["logits"]
action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions) action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
actor_loss = self.actor_loss_fn( actor_loss = self.actor_loss_fn(
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
) )
actor_loss = (1 - self.ptx_coef) * actor_loss
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
# ptx loss # ptx loss
if self.ptx_coef != 0: if self.ptx_coef != 0:
batch = self.pretrain_dataloader.next() batch = self.pretrain_dataloader.next()
batch = to_device(batch, self.device) batch = to_device(batch, self.device)
ptx_log_probs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"])["logits"] ptx_log_probs = self.actor(batch["input_ids"], batch["attention_mask"])["logits"]
ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch["labels"]) ptx_loss = self.ptx_coef * self.ptx_loss_fn(ptx_log_probs, batch["labels"])
actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef) self.strategy.backward(ptx_loss, self.actor, self.actor_optim)
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
self.strategy.optimizer_step(self.actor_optim) self.strategy.optimizer_step(self.actor_optim)
self.actor_optim.zero_grad() self.actor_optim.zero_grad()
# value loss # value loss
values = self.critic( values = self.critic(experience.sequences, attention_mask=experience.attention_mask)
experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask critic_loss = self.critic_loss_fn(values, experience.values, experience.reward)
)
critic_loss = self.critic_loss_fn(
values, experience.values, experience.reward, action_mask=experience.action_mask
)
critic_loss = critic_loss * self.vf_coef critic_loss = critic_loss * self.vf_coef
self.strategy.backward(critic_loss, self.critic, self.critic_optim) self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim) self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad() self.critic_optim.zero_grad()
return {"reward": experience.reward.mean().item()}
def _learn(self, update_step: int): def _learn(self, update_step: int):
if self.offload_inference_models: if self.offload_inference_models:
self.experience_maker.initial_model.to("cpu") self.experience_maker.initial_model.to("cpu")
@ -166,8 +189,8 @@ class PPOTrainer(OnPolicyTrainer):
experience = self.data_buffer.sample() experience = self.data_buffer.sample()
self._on_learn_batch_start() self._on_learn_batch_start()
experience.to_device(self.device) experience.to_device(self.device)
metrics = self._training_step(experience) self._training_step(experience)
self._on_learn_batch_end(metrics, experience) self._on_learn_batch_end(experience)
else: else:
if isinstance(self.dataloader.sampler, DistributedSampler): if isinstance(self.dataloader.sampler, DistributedSampler):
self.dataloader.sampler.set_epoch(update_step) self.dataloader.sampler.set_epoch(update_step)
@ -175,6 +198,5 @@ class PPOTrainer(OnPolicyTrainer):
for experience in pbar: for experience in pbar:
self._on_learn_batch_start() self._on_learn_batch_start()
experience.to_device(self.device) experience.to_device(self.device)
metrics = self._training_step(experience) self._training_step(experience)
self._on_learn_batch_end(metrics, experience) self._on_learn_batch_end(experience)
pbar.set_postfix(metrics)

View File

@ -1,7 +1,5 @@
from datetime import datetime from typing import Callable, Optional
from typing import Callable
import pandas as pd
import torch import torch
import tqdm import tqdm
from torch.optim import Optimizer from torch.optim import Optimizer
@ -40,10 +38,12 @@ class RewardModelTrainer(SLTrainer):
self.loss_fn = loss_fn self.loss_fn = loss_fn
self.scheduler = lr_scheduler self.scheduler = lr_scheduler
self.num_train_step = 0
def _eval(self, epoch): def _eval(self, epoch):
if self.eval_dataloader is not None: if self.eval_dataloader is not None:
self.model.eval() self.model.eval()
dist, on, cnt = 0, 0, 0 dist, num_correct, num_samples = 0, 0, 0
with torch.no_grad(): with torch.no_grad():
for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader: for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
@ -52,27 +52,21 @@ class RewardModelTrainer(SLTrainer):
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask) chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask) reject_reward = self.model(reject_ids, attention_mask=r_mask)
for i in range(len(chosen_reward)): num_samples += chosen_ids.size(0)
cnt += 1 num_correct += (chosen_reward > reject_reward).sum().item()
if chosen_reward[i] > reject_reward[i]:
on += 1
dist += (chosen_reward - reject_reward).mean().item() dist += (chosen_reward - reject_reward).mean().item()
self.dist = dist / len(self.eval_dataloader) self.dist = dist / len(self.eval_dataloader)
self.acc = on / cnt self.acc = num_correct / num_samples
if is_rank_0(): if self.writer:
log = pd.DataFrame( self.writer.add_scalar("eval/dist", self.dist, epoch)
[[(epoch + 1) * len(self.train_dataloader), self.loss.item(), self.dist, self.acc]], self.writer.add_scalar("eval/acc", self.acc, epoch)
columns=["step", "loss", "dist", "acc"],
)
log.to_csv("log.csv", mode="a", header=False, index=False)
def _train(self, epoch): def _train(self, epoch):
self.model.train() self.model.train()
step_bar = tqdm.trange( step_bar = tqdm.trange(
len(self.train_dataloader), desc="Train step of epoch %d" % epoch, disable=not is_rank_0() len(self.train_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0()
) )
cnt = 0
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader: for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
@ -80,26 +74,50 @@ class RewardModelTrainer(SLTrainer):
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask) chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask) reject_reward = self.model(reject_ids, attention_mask=r_mask)
self.loss = self.loss_fn(chosen_reward, reject_reward) loss = self.loss_fn(chosen_reward, reject_reward)
self.strategy.backward(self.loss, self.model, self.optimizer) self.strategy.backward(loss, self.model, self.optimizer)
self.strategy.optimizer_step(self.optimizer) self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad() self.optimizer.zero_grad()
cnt += 1 if self.writer:
if cnt % 100 == 0: self.writer.add_scalar("train/loss", loss.item(), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar("train/dist", (chosen_reward - reject_reward).mean().item(), self.num_train_step)
self.writer.add_scalar(
"train/acc", (chosen_reward > reject_reward).float().mean().item(), self.num_train_step
)
self.num_train_step += 1
if self.num_train_step % 100 == 0:
self.scheduler.step() self.scheduler.step()
step_bar.update() step_bar.update()
step_bar.close() step_bar.close()
def _before_fit(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, eval_dataloader: DataLoader): def _before_fit(
self,
train_dataloader: DataLoader,
eval_dataloader: DataLoader,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
""" """
Args: Args:
train_dataloader (DataLoader): the dataloader to use for training train_dataloader (DataLoader): the dataloader to use for training
valid_dataloader (DataLoader): the dataloader to use for validation
eval_dataloader (DataLoader): the dataloader to use for evaluation eval_dataloader (DataLoader): the dataloader to use for evaluation
""" """
super()._before_fit()
self.datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
self.train_dataloader = train_dataloader self.train_dataloader = train_dataloader
self.valid_dataloader = valid_dataloader
self.eval_dataloader = eval_dataloader self.eval_dataloader = eval_dataloader
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
wandb.init(project="Coati-rm", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "rm")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)

View File

@ -1,10 +1,8 @@
import time
from typing import Optional from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import tqdm import tqdm
import wandb
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -48,38 +46,34 @@ class SFTTrainer(SLTrainer):
self.accumulation_steps = accumulation_steps self.accumulation_steps = accumulation_steps
self.scheduler = lr_scheduler self.scheduler = lr_scheduler
self.num_train_step = 0
self.num_eval_step = 0
def _train(self, epoch: int): def _train(self, epoch: int):
self.model.train() self.model.train()
for batch_id, batch in enumerate(self.train_dataloader): step_bar = tqdm.trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device()) batch = to_device(batch, torch.cuda.current_device())
if "attention_mask" in batch: outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) loss = outputs.loss / self.accumulation_steps
else:
outputs = self.model(batch["input_ids"], labels=batch["labels"])
loss = outputs.loss
loss = loss / self.accumulation_steps
self.strategy.backward(loss, self.model, self.optimizer)
self.total_loss += loss.item() self.total_loss += loss.item()
self.strategy.backward(loss, self.model, self.optimizer)
# gradient accumulation # gradient accumulation
if (batch_id + 1) % self.accumulation_steps == 0: if (i + 1) % self.accumulation_steps == 0:
self.strategy.optimizer_step(self.optimizer) self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad() self.optimizer.zero_grad()
self.scheduler.step() self.scheduler.step()
if is_rank_0() and self.use_wandb: if self.writer:
wandb.log( self.writer.add_scalar("train/loss", self.total_loss, self.num_train_step)
{ self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
"loss": self.total_loss / self.accumulation_steps, self.num_train_step += 1
"lr": self.scheduler.get_last_lr()[0],
"epoch": epoch,
"batch_id": batch_id,
}
)
self.total_loss = 0 self.total_loss = 0
self.step_bar.update() step_bar.update()
step_bar.close()
def _eval(self, epoch: int): def _eval(self, epoch: int):
if self.eval_dataloader is not None: if self.eval_dataloader is not None:
@ -91,20 +85,21 @@ class SFTTrainer(SLTrainer):
outputs = self.model( outputs = self.model(
batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"] batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]
) )
loss = outputs.loss loss_sum += outputs.loss.item()
loss_sum += loss.item()
num_seen += batch["input_ids"].size(0) num_seen += batch["input_ids"].size(0)
loss_mean = loss_sum / num_seen loss_mean = loss_sum / num_seen
if dist.get_rank() == 0: if dist.get_rank() == 0:
self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}") self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}")
if self.writer:
self.writer.add_scalar("eval/loss", loss_mean, self.num_eval_step)
self.num_eval_step += 1
def _before_fit( def _before_fit(
self, self,
train_dataloader: DataLoader, train_dataloader: DataLoader,
eval_dataloader: Optional[DataLoader] = None, eval_dataloader: Optional[DataLoader] = None,
logger: Optional[DistributedLogger] = None, logger: Optional[DistributedLogger] = None,
log_dir: Optional[str] = None,
use_wandb: bool = False, use_wandb: bool = False,
): ):
""" """
@ -116,15 +111,20 @@ class SFTTrainer(SLTrainer):
self.eval_dataloader = eval_dataloader self.eval_dataloader = eval_dataloader
self.logger = logger self.logger = logger
self.use_wandb = use_wandb self.writer = None
if use_wandb: if use_wandb and is_rank_0():
wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) assert log_dir is not None, "log_dir must be provided when use_wandb is True"
wandb.watch(self.model) import wandb
wandb.init(project="Coati-sft", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "sft")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
self.total_loss = 0 self.total_loss = 0
self.no_epoch_bar = True
self.step_bar = tqdm.trange(
len(self.train_dataloader) // self.accumulation_steps * self.max_epochs,
desc=f"steps",
disable=not is_rank_0(),
)

View File

@ -1,17 +1,13 @@
import warnings import warnings
from typing import Optional from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
import colossalai import colossalai
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.gemini_plugin import GeminiModel
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.tensor import ProcessGroup, ShardSpec from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.zero.gemini.gemini_ddp import GeminiDDP from colossalai.zero.gemini.gemini_ddp import GeminiDDP
from .ddp import DDPStrategy from .ddp import DDPStrategy
@ -65,14 +61,11 @@ class LowLevelZeroStrategy(DDPStrategy):
assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"' assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"'
plugin_initializer = lambda: LowLevelZeroPlugin( plugin_initializer = lambda: LowLevelZeroPlugin(
# zero_config
stage=stage, stage=stage,
precision=precision, precision=precision,
# zero_optim_config
reduce_bucket_size_in_m=reduce_bucket_size, reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication, overlap_communication=overlap_communication,
cpu_offload=(placement_policy == "cpu"), cpu_offload=(placement_policy == "cpu"),
# optim_config
initial_scale=initial_scale, initial_scale=initial_scale,
growth_factor=growth_factor, growth_factor=growth_factor,
backoff_factor=backoff_factor, backoff_factor=backoff_factor,
@ -136,7 +129,7 @@ class GeminiStrategy(DDPStrategy):
self, self,
seed: int = 42, seed: int = 42,
shard_init: bool = False, # only for stage 3 shard_init: bool = False, # only for stage 3
placement_policy: str = "cuda", placement_policy: str = "auto",
pin_memory: bool = True, # only for stage 3 pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3 force_outputs_fp32: bool = False, # only for stage 3
search_range_m: int = 32, # only for stage 3 search_range_m: int = 32, # only for stage 3
@ -153,8 +146,6 @@ class GeminiStrategy(DDPStrategy):
max_norm: float = 0.0, max_norm: float = 0.0,
norm_type: float = 2.0, norm_type: float = 2.0,
) -> None: ) -> None:
assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
# TODO(ver217): support shard_init when using from_pretrained() # TODO(ver217): support shard_init when using from_pretrained()
if shard_init: if shard_init:
warnings.warn( warnings.warn(
@ -167,8 +158,7 @@ class GeminiStrategy(DDPStrategy):
# NOTE: dist should be initialized before calling get_current_device() # NOTE: dist should be initialized before calling get_current_device()
plugin_initializer = lambda: GeminiPlugin( plugin_initializer = lambda: GeminiPlugin(
# gemini_config chunk_init_device=get_current_device(),
device=get_current_device(),
placement_policy=placement_policy, placement_policy=placement_policy,
precision="fp16", precision="fp16",
pin_memory=pin_memory, pin_memory=pin_memory,
@ -177,9 +167,7 @@ class GeminiStrategy(DDPStrategy):
search_range_m=search_range_m, search_range_m=search_range_m,
hidden_dim=hidden_dim, hidden_dim=hidden_dim,
min_chunk_size_m=min_chunk_size_m, min_chunk_size_m=min_chunk_size_m,
# zero_optim_config
gpu_margin_mem_ratio=gpu_margin_mem_ratio, gpu_margin_mem_ratio=gpu_margin_mem_ratio,
# optim_config
initial_scale=initial_scale, initial_scale=initial_scale,
growth_factor=growth_factor, growth_factor=growth_factor,
backoff_factor=backoff_factor, backoff_factor=backoff_factor,
@ -200,15 +188,8 @@ class GeminiStrategy(DDPStrategy):
colossalai.launch_from_torch({}, seed=self.seed) colossalai.launch_from_torch({}, seed=self.seed)
def model_init_context(self): def model_init_context(self):
world_size = dist.get_world_size() return LazyInitContext(default_device=get_current_device())
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
return ColoInitContext(
device=get_current_device(), dtype=torch.half, default_pg=shard_pg, default_dist_spec=default_dist_spec
)
def unwrap_model(self, model: nn.Module) -> nn.Module: def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, GeminiModel) assert isinstance(model, GeminiDDP)
ddp_model = model.unwrap() return model.module
assert isinstance(ddp_model, GeminiDDP)
return ddp_model.module

View File

@ -45,9 +45,17 @@ def eval(args):
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
actor.eval() actor.eval()
tokenizer.padding_side = "left"
input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device()) input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device())
outputs = generate( outputs = generate(
actor, input_ids, max_length=args.max_length, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1 actor,
input_ids,
tokenizer=tokenizer,
max_length=args.max_length,
do_sample=True,
top_k=50,
top_p=0.95,
num_return_sequences=1,
) )
output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True) output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
print(f"[Output]: {''.join(output)}") print(f"[Output]: {''.join(output)}")

View File

@ -1,3 +1,3 @@
pandas>=1.4.1 pandas>=1.4.1
sentencepiece sentencepiece
colossalai==0.3.1 colossalai>=0.3.1

View File

@ -23,7 +23,7 @@ def main(args):
if args.strategy == "ddp": if args.strategy == "ddp":
strategy = DDPStrategy() strategy = DDPStrategy()
elif args.strategy == "colossalai_gemini": elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5) strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5)
elif args.strategy == "colossalai_zero2": elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else: else:
@ -65,8 +65,8 @@ def main(args):
if args.rm_path is not None: if args.rm_path is not None:
reward_model.load_state_dict(state_dict, strict=False) reward_model.load_state_dict(state_dict, strict=False)
initial_model.to(torch.float16).to(torch.cuda.current_device()) initial_model.to(torch.bfloat16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device()) reward_model.to(torch.bfloat16).to(torch.cuda.current_device())
if args.model == "gpt2": if args.model == "gpt2":
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank) actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
@ -80,13 +80,13 @@ def main(args):
raise ValueError(f'Unsupported actor model "{args.model}"') raise ValueError(f'Unsupported actor model "{args.model}"')
if rm_model_name == "gpt2": if rm_model_name == "gpt2":
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == "bloom": elif rm_model_name == "bloom":
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == "opt": elif rm_model_name == "opt":
critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == "llama": elif rm_model_name == "llama":
critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
else: else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"') raise ValueError(f'Unsupported reward model "{rm_model_name}"')
@ -94,17 +94,16 @@ def main(args):
critic.load_state_dict(state_dict, strict=False) critic.load_state_dict(state_dict, strict=False)
del state_dict del state_dict
if args.strategy != "colossalai_gemini": actor.to(torch.bfloat16).to(torch.cuda.current_device())
critic.to(torch.float16).to(torch.cuda.current_device()) critic.to(torch.bfloat16).to(torch.cuda.current_device())
actor.to(torch.float16).to(torch.cuda.current_device())
# configure optimizer # configure optimizer
if args.strategy.startswith("colossalai"): if args.strategy.startswith("colossalai"):
actor_optim = HybridAdam(actor.parameters(), lr=1e-7) actor_optim = HybridAdam(actor.parameters(), lr=args.lr)
critic_optim = HybridAdam(critic.parameters(), lr=1e-7) critic_optim = HybridAdam(critic.parameters(), lr=args.lr)
else: else:
actor_optim = Adam(actor.parameters(), lr=1e-7) actor_optim = Adam(actor.parameters(), lr=args.lr)
critic_optim = Adam(critic.parameters(), lr=1e-7) critic_optim = Adam(critic.parameters(), lr=args.lr)
# configure tokenizer # configure tokenizer
if args.model == "gpt2": if args.model == "gpt2":
@ -126,8 +125,15 @@ def main(args):
tokenizer.pad_token = tokenizer.unk_token tokenizer.pad_token = tokenizer.unk_token
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
# NOTE: generate() requires padding_side to be "left"
tokenizer.padding_side = "left"
prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_dataset, max_datasets_size=16384) prompt_dataset = PromptDataset(
tokenizer=tokenizer,
data_path=args.prompt_dataset,
max_datasets_size=args.max_datasets_size,
max_length=args.max_input_len,
)
if dist.is_initialized() and dist.get_world_size() > 1: if dist.is_initialized() and dist.get_world_size() > 1:
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else: else:
@ -137,7 +143,10 @@ def main(args):
) )
pretrain_dataset = SupervisedDataset( pretrain_dataset = SupervisedDataset(
tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384, max_length=args.max_input_len tokenizer=tokenizer,
data_path=args.pretrain_dataset,
max_datasets_size=args.max_datasets_size,
max_length=args.max_input_len,
) )
if dist.is_initialized() and dist.get_world_size() > 1: if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True) pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
@ -161,6 +170,7 @@ def main(args):
initial_model, initial_model,
actor_optim, actor_optim,
critic_optim, critic_optim,
tokenizer=tokenizer,
kl_coef=args.kl_coef, kl_coef=args.kl_coef,
ptx_coef=args.ptx_coef, ptx_coef=args.ptx_coef,
train_batch_size=args.train_batch_size, train_batch_size=args.train_batch_size,
@ -169,17 +179,17 @@ def main(args):
do_sample=True, do_sample=True,
temperature=1.0, temperature=1.0,
top_k=50, top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
offload_inference_models=args.strategy != "colossalai_gemini", offload_inference_models=args.strategy != "colossalai_gemini",
) )
trainer.fit( trainer.fit(
prompt_dataloader=prompt_dataloader,
pretrain_dataloader=pretrain_dataloader,
num_episodes=args.num_episodes, num_episodes=args.num_episodes,
num_collect_steps=args.num_collect_steps, num_collect_steps=args.num_collect_steps,
num_update_steps=args.num_update_steps, num_update_steps=args.num_update_steps,
prompt_dataloader=prompt_dataloader,
pretrain_dataloader=pretrain_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
) )
# save model checkpoint after fitting # save model checkpoint after fitting
@ -195,6 +205,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset") parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset")
parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset") parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
parser.add_argument("--max_datasets_size", type=int, default=50000)
parser.add_argument( parser.add_argument(
"--strategy", "--strategy",
choices=["ddp", "colossalai_gemini", "colossalai_zero2"], choices=["ddp", "colossalai_gemini", "colossalai_zero2"],
@ -216,9 +227,12 @@ if __name__ == "__main__":
parser.add_argument("--ptx_batch_size", type=int, default=1) parser.add_argument("--ptx_batch_size", type=int, default=1)
parser.add_argument("--experience_batch_size", type=int, default=8) parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--lr", type=float, default=1e-7)
parser.add_argument("--kl_coef", type=float, default=0.1) parser.add_argument("--kl_coef", type=float, default=0.1)
parser.add_argument("--ptx_coef", type=float, default=0.9) parser.add_argument("--ptx_coef", type=float, default=0.9)
parser.add_argument("--max_input_len", type=int, default=96) parser.add_argument("--max_input_len", type=int, default=96)
parser.add_argument("--max_seq_len", type=int, default=128) parser.add_argument("--max_seq_len", type=int, default=128)
parser.add_argument("--log_dir", default="logs", type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -1,5 +1,4 @@
import argparse import argparse
from random import randint
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -27,7 +26,7 @@ def train(args):
if args.strategy == "ddp": if args.strategy == "ddp":
strategy = DDPStrategy() strategy = DDPStrategy()
elif args.strategy == "colossalai_gemini": elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cuda") strategy = GeminiStrategy(placement_policy="auto")
elif args.strategy == "colossalai_zero2": elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else: else:
@ -46,7 +45,7 @@ def train(args):
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
model.to(torch.float16).to(torch.cuda.current_device()) model.to(torch.bfloat16).to(torch.cuda.current_device())
if args.model_path is not None: if args.model_path is not None:
state_dict = torch.load(args.model_path) state_dict = torch.load(args.model_path)
@ -75,9 +74,9 @@ def train(args):
# configure optimizer # configure optimizer
if args.strategy.startswith("colossalai"): if args.strategy.startswith("colossalai"):
optim = HybridAdam(model.parameters(), lr=5e-6) optim = HybridAdam(model.parameters(), lr=args.lr)
else: else:
optim = Adam(model.parameters(), lr=5e-6) optim = Adam(model.parameters(), lr=args.lr)
# configure loss function # configure loss function
if args.loss_fn == "log_sig": if args.loss_fn == "log_sig":
@ -93,21 +92,14 @@ def train(args):
else: else:
data = load_dataset(args.dataset) data = load_dataset(args.dataset)
if args.test: train_data = data["train"].select(range(min(args.max_datasets_size, len(data["train"]))))
train_data = data["train"].select(range(20)) eval_data = data["test"].select(range(min(args.max_datasets_size, len(data["test"]))))
eval_data = data["test"].select(range(5))
else:
train_data = data["train"]
eval_data = data["test"]
valid_data = data["test"].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5)))
if args.dataset == "Dahoas/rm-static": if args.dataset == "Dahoas/rm-static":
train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len) train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len)
valid_dataset = RmStaticDataset(valid_data, tokenizer, args.max_len)
eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len) eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len)
elif args.dataset == "Anthropic/hh-rlhf": elif args.dataset == "Anthropic/hh-rlhf":
train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len) train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len)
valid_dataset = HhRlhfDataset(valid_data, tokenizer, args.max_len)
eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len) eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len)
else: else:
raise ValueError(f'Unsupported dataset "{args.dataset}"') raise ValueError(f'Unsupported dataset "{args.dataset}"')
@ -121,14 +113,6 @@ def train(args):
rank=dist.get_rank(), rank=dist.get_rank(),
num_replicas=dist.get_world_size(), num_replicas=dist.get_world_size(),
) )
valid_sampler = DistributedSampler(
valid_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
)
eval_sampler = DistributedSampler( eval_sampler = DistributedSampler(
eval_dataset, eval_dataset,
shuffle=True, shuffle=True,
@ -139,7 +123,6 @@ def train(args):
) )
else: else:
train_sampler = None train_sampler = None
valid_sampler = None
eval_sampler = None eval_sampler = None
train_dataloader = DataLoader( train_dataloader = DataLoader(
@ -150,14 +133,6 @@ def train(args):
pin_memory=True, pin_memory=True,
) )
valid_dataloader = DataLoader(
valid_dataset,
shuffle=(valid_sampler is None),
sampler=valid_sampler,
batch_size=args.batch_size,
pin_memory=True,
)
eval_dataloader = DataLoader( eval_dataloader = DataLoader(
eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True
) )
@ -176,7 +151,12 @@ def train(args):
max_epochs=args.max_epochs, max_epochs=args.max_epochs,
) )
trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader) trainer.fit(
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
# save model checkpoint after fitting on only rank0 # save model checkpoint after fitting on only rank0
strategy.save_model(model, args.save_path, only_rank0=True) strategy.save_model(model, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks # save optimizer checkpoint on all ranks
@ -200,12 +180,15 @@ if __name__ == "__main__":
"--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static" "--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static"
) )
parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None) parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None)
parser.add_argument("--max_datasets_size", type=int, default=1000000)
parser.add_argument("--save_path", type=str, default="rm_ckpt") parser.add_argument("--save_path", type=str, default="rm_ckpt")
parser.add_argument("--max_epochs", type=int, default=1) parser.add_argument("--max_epochs", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--max_len", type=int, default=512) parser.add_argument("--max_len", type=int, default=512)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--lr", type=float, default=9e-6)
parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"]) parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
parser.add_argument("--test", type=bool, default=False) parser.add_argument("--log_dir", default="logs", type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
args = parser.parse_args() args = parser.parse_args()
train(args) train(args)

View File

@ -16,7 +16,10 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
set_n_least_used_CUDA_VISIBLE_DEVICES 2 set_n_least_used_CUDA_VISIBLE_DEVICES 2
torchrun --standalone --nproc_per_node=2 train_reward_model.py \ torchrun --standalone --nproc_per_node=2 train_reward_model.py \
--model 'bloom' \ --pretrain 'gpt2' \
--model 'gpt2' \
--strategy colossalai_zero2 \ --strategy colossalai_zero2 \
--loss_fn 'log_sig' \ --loss_fn 'log_exp' \
--dataset 'Anthropic/hh-rlhf' --dataset 'Anthropic/hh-rlhf' \
--batch_size 16 \
--max_epochs 10

View File

@ -23,7 +23,6 @@ from transformers.trainer import get_scheduler
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter
def train(args): def train(args):
@ -31,7 +30,7 @@ def train(args):
if args.strategy == "ddp": if args.strategy == "ddp":
strategy = DDPStrategy() strategy = DDPStrategy()
elif args.strategy == "colossalai_gemini": elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cuda") strategy = GeminiStrategy(placement_policy="auto")
elif args.strategy == "colossalai_zero2": elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
elif args.strategy == "colossalai_zero2_cpu": elif args.strategy == "colossalai_zero2_cpu":
@ -57,7 +56,7 @@ def train(args):
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
model.to(torch.float16).to(torch.cuda.current_device()) model.to(torch.bfloat16).to(torch.cuda.current_device())
# configure tokenizer # configure tokenizer
if args.model == "gpt2": if args.model == "gpt2":
@ -84,28 +83,21 @@ def train(args):
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
if args.model == "llama" and args.strategy == "colossalai_gemini":
# this is a hack to deal with the resized embedding
# to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
for name, param in model.named_parameters():
if not isinstance(param, ColoParameter):
sub_module_name = ".".join(name.split(".")[:-1])
weight_name = name.split(".")[-1]
sub_module = model.get_submodule(sub_module_name)
setattr(sub_module, weight_name, ColoParameter(param))
# configure optimizer # configure optimizer
if args.strategy.startswith("colossalai"): if args.strategy.startswith("colossalai"):
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0) optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
else: else:
optim = Adam(model.parameters(), lr=args.lr) optim = Adam(model.parameters(), lr=args.lr)
logger = get_dist_logger()
# configure dataset # configure dataset
if args.dataset == "yizhongw/self_instruct": if args.dataset == "yizhongw/self_instruct":
train_data = load_dataset(args.dataset, "super_natural_instructions", split="train") train_data = load_dataset(args.dataset, "super_natural_instructions", split="train")
eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test") eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test")
if args.max_datasets_size is not None:
train_data = train_data.select(range(min(args.max_datasets_size, len(train_data))))
eval_data = eval_data.select(range(min(args.max_datasets_size, len(eval_data))))
train_dataset = SFTDataset(train_data, tokenizer, args.max_len) train_dataset = SFTDataset(train_data, tokenizer, args.max_len)
eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len) eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len)
@ -176,8 +168,13 @@ def train(args):
accumulation_steps=args.accumulation_steps, accumulation_steps=args.accumulation_steps,
) )
logger = get_dist_logger()
trainer.fit( trainer.fit(
train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, logger=logger, use_wandb=args.use_wandb train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
logger=logger,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
) )
# save model checkpoint after fitting on only rank0 # save model checkpoint after fitting on only rank0
@ -207,9 +204,9 @@ if __name__ == "__main__":
parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--max_len", type=int, default=512) parser.add_argument("--max_len", type=int, default=512)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
parser.add_argument("--lr", type=float, default=5e-6) parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8) parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default="logs", type=str)
parser.add_argument("--use_wandb", default=False, action="store_true") parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true") parser.add_argument("--grad_checkpoint", default=False, action="store_true")
args = parser.parse_args() args = parser.parse_args()

View File

@ -19,7 +19,6 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
--pretrain "/path/to/LLaMa-7B/" \ --pretrain "/path/to/LLaMa-7B/" \
--model 'llama' \ --model 'llama' \
--strategy colossalai_zero2 \ --strategy colossalai_zero2 \
--log_interval 10 \
--save_path /path/to/Coati-7B \ --save_path /path/to/Coati-7B \
--dataset /path/to/data.json \ --dataset /path/to/data.json \
--batch_size 4 \ --batch_size 4 \

View File

@ -1,2 +1,2 @@
pytest pytest
colossalai==0.3.1 colossalai>=0.3.1

View File

@ -2,7 +2,7 @@ transformers>=4.20.1
tqdm tqdm
datasets datasets
loralib loralib
colossalai==0.3.1 colossalai>=0.3.1
torch<2.0.0, >=1.12.1 torch<2.0.0, >=1.12.1
langchain langchain
tokenizers tokenizers
@ -11,3 +11,4 @@ sse_starlette
wandb wandb
sentencepiece sentencepiece
gpustat gpustat
tensorboard

View File

@ -25,8 +25,8 @@ def get_data(batch_size: int, seq_len: int = 10) -> dict:
def train_step(strategy: Strategy, actor: GPTActor, actor_optim: HybridAdam, batch_size: int = 8): def train_step(strategy: Strategy, actor: GPTActor, actor_optim: HybridAdam, batch_size: int = 8):
data = get_data(batch_size) data = get_data(batch_size)
action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool) action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool)
actor_output = actor(data["input_ids"], data["attention_mask"]) actor_logits = actor(data["input_ids"], data["attention_mask"])["logits"]
action_log_probs = calc_action_log_probs(actor_output, data["input_ids"], action_mask.size(1)) action_log_probs = calc_action_log_probs(actor_logits, data["input_ids"], action_mask.size(1))
loss = action_log_probs.sum() loss = action_log_probs.sum()
strategy.backward(loss, actor, actor_optim) strategy.backward(loss, actor, actor_optim)
strategy.optimizer_step(actor_optim) strategy.optimizer_step(actor_optim)
@ -36,7 +36,7 @@ def run_test_checkpoint(strategy_name: str, shard: bool):
if strategy_name == "ddp": if strategy_name == "ddp":
strategy = DDPStrategy() strategy = DDPStrategy()
elif strategy_name == "colossalai_gemini": elif strategy_name == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5) strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5)
elif strategy_name == "colossalai_zero2": elif strategy_name == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else: else:

View File

@ -226,7 +226,9 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size:
check_content(input_ids.masked_select(attention_mask), tokenizer, model) check_content(input_ids.masked_select(attention_mask), tokenizer, model)
assert torch.all(attention_mask) assert torch.all(attention_mask)
ignore_mask = labels == IGNORE_INDEX ignore_mask = labels == IGNORE_INDEX
check_content(input_ids.masked_select(ignore_mask), tokenizer, model) prompt_mask = torch.logical_and(ignore_mask, attention_mask)
check_content(input_ids.masked_select(prompt_mask), tokenizer, model)
assert torch.all(input_ids.masked_select(ignore_mask ^ prompt_mask) == tokenizer.pad_token_id)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,5 +1,5 @@
import copy
import os import os
from copy import deepcopy
import pytest import pytest
import torch import torch
@ -8,6 +8,7 @@ from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import NaiveExperienceMaker from coati.experience_maker import NaiveExperienceMaker
from coati.models.base import RewardModel from coati.models.base import RewardModel
from coati.models.gpt import GPTActor, GPTCritic from coati.models.gpt import GPTActor, GPTCritic
from coati.trainer.ppo import _set_default_generate_kwargs
from coati.trainer.strategies import DDPStrategy, GeminiStrategy from coati.trainer.strategies import DDPStrategy, GeminiStrategy
from coati.trainer.strategies.colossalai import LowLevelZeroStrategy from coati.trainer.strategies.colossalai import LowLevelZeroStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config from transformers.models.gpt2.configuration_gpt2 import GPT2Config
@ -42,27 +43,38 @@ def make_and_consume_experience(strategy):
elif strategy == "colossalai-zero2": elif strategy == "colossalai-zero2":
strategy = LowLevelZeroStrategy() strategy = LowLevelZeroStrategy()
elif strategy == "colossalai-gemini": elif strategy == "colossalai-gemini":
strategy = GeminiStrategy(placement_policy="cuda") strategy = GeminiStrategy(placement_policy="static")
else: else:
raise ValueError(f'Unsupported strategy "{strategy}"') raise ValueError(f'Unsupported strategy "{strategy}"')
actor = GPTActor(config=GPT_CONFIG).cuda() with strategy.model_init_context():
critic = GPTCritic(config=GPT_CONFIG).cuda() actor = GPTActor(config=GPT_CONFIG).cuda()
critic = GPTCritic(config=GPT_CONFIG).cuda()
initial_model = deepcopy(actor) initial_model = GPTActor(config=GPT_CONFIG).cuda()
reward_model = RewardModel(deepcopy(critic.model)).cuda() reward_model = RewardModel(model=copy.deepcopy(critic.model)).cuda()
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model) actor, critic, initial_model, reward_model = strategy.prepare(actor, critic, initial_model, reward_model)
class MockTokenizer:
def __init__(self):
self.padding_side = "left"
self.eos_token_id = 0
self.pad_token_id = 0
tokenizer = MockTokenizer()
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer)
data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False) data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)
generate_kwargs = dict(do_sample=True, max_length=16)
generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
# experience of all ranks should be the same # experience of all ranks should be the same
for _ in range(2): for _ in range(2):
data = get_data(EXPERIENCE_BATCH_SIZE) data = get_data(EXPERIENCE_BATCH_SIZE)
assert gather_and_equal(data["input_ids"]) assert gather_and_equal(data["input_ids"])
assert gather_and_equal(data["attention_mask"]) assert gather_and_equal(data["attention_mask"])
experience = experience_maker.make_experience( experience = experience_maker.make_experience(**data, do_sample=True, max_length=16)
**data, do_sample=True, max_length=16, eos_token_id=50256, pad_token_id=50256
)
assert gather_and_equal(experience.sequences) assert gather_and_equal(experience.sequences)
assert gather_and_equal(experience.action_log_probs) assert gather_and_equal(experience.action_log_probs)
assert gather_and_equal(experience.values) assert gather_and_equal(experience.values)
@ -115,4 +127,4 @@ def test_experience(world_size, strategy):
if __name__ == "__main__": if __name__ == "__main__":
test_experience(2, "colossalai") test_experience(2, "colossalai-zero2")

View File

@ -14,7 +14,7 @@ from coati.models.llama import LlamaActor
from coati.models.lora import LoraLinear, convert_to_lora_module from coati.models.lora import LoraLinear, convert_to_lora_module
from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from coati.models.opt import OPTRM, OPTActor, OPTCritic from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean from coati.models.utils import calc_action_log_probs, masked_mean
@pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("batch_size", [4])
@ -27,7 +27,6 @@ from coati.models.utils import calc_action_log_probs, compute_reward, masked_mea
# HACK: skip llama due to long execution time # HACK: skip llama due to long execution time
# lambda: LlamaActor(), # lambda: LlamaActor(),
lambda: OPTActor(), lambda: OPTActor(),
# lambda: ChatGLMActor(),
], ],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -43,9 +42,16 @@ from coati.models.utils import calc_action_log_probs, compute_reward, masked_mea
], ],
) )
def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]): def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
class MockTokenizer:
def __init__(self):
self.padding_side = "left"
self.eos_token_id = 0
self.pad_token_id = 0
actor = actor_maker() actor = actor_maker()
input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda() input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
sequences = generate(actor.cuda(), input_ids, **generate_kwargs) tokenizer = MockTokenizer()
sequences = generate(actor.cuda(), input_ids, tokenizer, **generate_kwargs)
assert sequences.shape == (batch_size, generate_kwargs["max_length"]) assert sequences.shape == (batch_size, generate_kwargs["max_length"])
@ -55,24 +61,12 @@ def test_utils():
assert fn_output.dim() == 0 assert fn_output.dim() == 0
assert torch.allclose(fn_output, torch.tensor(1.0)) assert torch.allclose(fn_output, torch.tensor(1.0))
batch_size = 4
num_labels = 10
fn_input = {
"r": torch.ones((batch_size,)),
"kl_coef": 1.0,
"log_probs": torch.randn((batch_size, num_labels)),
"log_probs_base": torch.randn((batch_size, num_labels)),
"action_mask": torch.randint(0, 2, (batch_size, num_labels)),
}
fn_output = compute_reward(**fn_input)
assert fn_output.shape == (batch_size,)
batch_size = 4 batch_size = 4
seq_len = 32 seq_len = 32
num_labels = 10 num_labels = 10
num_actions = 2 num_actions = 2
fn_input = { fn_input = {
"output": {"logits": torch.randn((batch_size, seq_len, num_labels))}, "logits": torch.randn((batch_size, seq_len, num_labels)),
"sequences": torch.randint(0, num_labels, (batch_size, seq_len)), "sequences": torch.randint(0, num_labels, (batch_size, seq_len)),
"num_actions": num_actions, "num_actions": num_actions,
} }
@ -135,7 +129,6 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], b
} }
critic_input = { critic_input = {
"sequences": torch.randint(0, 100, (batch_size, seq_len)), "sequences": torch.randint(0, 100, (batch_size, seq_len)),
"action_mask": torch.randint(0, 2, (batch_size, seq_len)),
"attention_mask": torch.randint(0, 2, (batch_size, seq_len)), "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
} }
rm_input = { rm_input = {

View File

@ -24,8 +24,8 @@ if [ -z "$SFT_DATASET" ]; then
exit 1 exit 1
fi fi
if [ -z "$PROMPT_PATH" ]; then if [ -z "$PROMPT_DATASET" ]; then
echo "Please set \$PROMPT_PATH to the path to prompts csv." echo "Please set \$PROMPT_DATASET to the path to prompts csv."
exit 1 exit 1
fi fi
@ -74,11 +74,15 @@ echo "[Test]: testing sft ..."
# FIXME: This is a hack to skip tests that are not working # FIXME: This is a hack to skip tests that are not working
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation # - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# - llama-*: These tests can be passed locally, skipped for long execution time # - llama-*: These tests can be passed locally, skipped for long execution time
# - *-gemini: Gemini plugin does not support `from_pretrained` yet
SKIPPED_TESTS=( SKIPPED_TESTS=(
"gpt2-ddp" "gpt2-ddp"
"llama-ddp" "llama-ddp"
"llama-colossalai_gemini" "llama-colossalai_gemini"
"llama-colossalai_zero2" "llama-colossalai_zero2"
"gpt2-colossalai_gemini"
"opt-colossalai_gemini"
"bloom-colossalai_gemini"
) )
GRAD_CKPTS=('' '--grad_checkpoint') GRAD_CKPTS=('' '--grad_checkpoint')
@ -105,7 +109,7 @@ for lora_rank in '0' '4'; do
$pretrain_model --tokenizer $MODELS_DIR/$model \ $pretrain_model --tokenizer $MODELS_DIR/$model \
--model $model --strategy $strategy --lora_rank $lora_rank $grad_ckpt \ --model $model --strategy $strategy --lora_rank $lora_rank $grad_ckpt \
--dataset $SFT_DATASET --max_datasets_size 8 \ --dataset $SFT_DATASET --max_datasets_size 8 \
--max_epochs 1 --batch_size 1 --accumulation_steps 1 \ --max_epochs 1 --batch_size 1 --accumulation_steps 1 --lr 1e-8 \
--save_path $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} --save_path $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank}
passed=$? passed=$?
if [ $passed -eq 0 ]; then if [ $passed -eq 0 ]; then
@ -125,11 +129,15 @@ echo "[Test]: testing reward model ..."
# FIXME: This is a hack to skip tests that are not working # FIXME: This is a hack to skip tests that are not working
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation # - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# - llama-*: These tests can be passed locally, skipped for long execution time # - llama-*: These tests can be passed locally, skipped for long execution time
# - *-gemini: Gemini plugin does not support `from_pretrained` yet
SKIPPED_TESTS=( SKIPPED_TESTS=(
"gpt2-ddp" "gpt2-ddp"
"llama-ddp" "llama-ddp"
"llama-colossalai_gemini" "llama-colossalai_gemini"
"llama-colossalai_zero2" "llama-colossalai_zero2"
"gpt2-colossalai_gemini"
"opt-colossalai_gemini"
"bloom-colossalai_gemini"
) )
LOSS_FNS=('log_sig' 'log_exp') LOSS_FNS=('log_sig' 'log_exp')
@ -157,8 +165,9 @@ for lora_rank in '0' '4'; do
echo "[Test]: $model-$strategy-$lora_rank, attempt $i" echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_reward_model.py \ torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_reward_model.py \
$pretrain_model --tokenizer $MODELS_DIR/$model \ $pretrain_model --tokenizer $MODELS_DIR/$model \
--model $model --strategy $strategy --lora_rank $lora_rank --loss_fn $loss_fn \ --dataset $dataset --subset $subset --max_datasets_size 8 \
--dataset $dataset --subset $subset --test True --batch_size 1 \ --model $model --strategy $strategy --lora_rank $lora_rank \
--loss_fn $loss_fn --batch_size 1 --lr 1e-8 \
--save_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt --save_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
passed=$? passed=$?
if [ $passed -eq 0 ]; then if [ $passed -eq 0 ]; then
@ -178,11 +187,15 @@ echo "[Test]: testing RLHF ..."
# FIXME: This is a hack to skip tests that are not working # FIXME: This is a hack to skip tests that are not working
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation # - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# - llama-*: These tests can be passed locally, skipped for long execution time # - llama-*: These tests can be passed locally, skipped for long execution time
# - *-gemini: Gemini plugin does not support `from_pretrained` yet
SKIPPED_TESTS=( SKIPPED_TESTS=(
"gpt2-ddp" "gpt2-ddp"
"llama-ddp" "llama-ddp"
"llama-colossalai_gemini" "llama-colossalai_gemini"
"llama-colossalai_zero2" "llama-colossalai_zero2"
"gpt2-colossalai_gemini"
"opt-colossalai_gemini"
"bloom-colossalai_gemini"
) )
for model in ${MODELS[@]}; do for model in ${MODELS[@]}; do
@ -204,9 +217,9 @@ for model in ${MODELS[@]}; do
for i in $(seq $NUM_RETRY); do for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$strategy-$lora_rank, attempt $i" echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_prompts.py \ torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_prompts.py \
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \ --prompt_dataset $PROMPT_DATASET --pretrain_dataset $PRETRAIN_DATASET --max_datasets_size 32 \
--strategy $strategy --model $model --tokenizer $MODELS_DIR/$model \ --strategy $strategy --model $model --tokenizer $MODELS_DIR/$model \
--num_episodes 1 --num_collect_steps 1 --num_update_steps 1 \ --num_episodes 1 --num_collect_steps 1 --num_update_steps 1 --lr 1e-8 \
--experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \ --experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \
--pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \ --pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \
$rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \ $rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \

View File

@ -3,6 +3,7 @@ import time
import pytest import pytest
import torch import torch
from model_zoo import GPTLMLoss, get_gpt2_components
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
import colossalai import colossalai
@ -13,7 +14,6 @@ from colossalai.fx.profiler import parameter_size
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import spawn from colossalai.testing import spawn
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from model_zoo import GPTLMLoss, get_gpt2_components
def parse_args(): def parse_args():

View File

@ -3,6 +3,7 @@ import time
from functools import partial from functools import partial
import torch import torch
from model_zoo import model_builder
from torch import nn from torch import nn
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer
@ -12,7 +13,6 @@ from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology
from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine
from colossalai.legacy.pipeline.rpc.utils import rpc_run from colossalai.legacy.pipeline.rpc.utils import rpc_run
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from model_zoo import model_builder
def parse_args(): def parse_args():