mirror of https://github.com/hpcaitech/ColossalAI
[chat]: update rm, add wandb and fix bugs (#4471)
* feat: modify forward fn of critic and reward model * feat: modify calc_action_log_probs * to: add wandb in sft and rm trainer * feat: update train_sft * feat: update train_rm * style: modify type annotation and add warning * feat: pass tokenizer to ppo trainer * to: modify trainer base and maker base * feat: add wandb in ppo trainer * feat: pass tokenizer to generate * test: update generate fn tests * test: update train tests * fix: remove action_mask * feat: remove unused code * fix: fix wrong ignore_index * fix: fix mock tokenizer * chore: update requirements * revert: modify make_experience * fix: fix inference * fix: add padding side * style: modify _on_learn_batch_end * test: use mock tokenizer * fix: use bf16 to avoid overflow * fix: fix workflow * [chat] fix gemini strategy * [chat] fix * sync: update colossalai strategy * fix: fix args and model dtype * fix: fix checkpoint test * fix: fix requirements * fix: fix missing import and wrong arg * fix: temporarily skip gemini test in stage 3 * style: apply pre-commit * fix: temporarily skip gemini test in stage 1&2 --------- Co-authored-by: Mingyan Jiang <1829166702@qq.com>pull/4766/head
parent
07c2e3d09c
commit
7b9b86441f
|
@ -49,5 +49,5 @@ jobs:
|
||||||
NCCL_SHM_DISABLE: 1
|
NCCL_SHM_DISABLE: 1
|
||||||
MAX_JOBS: 8
|
MAX_JOBS: 8
|
||||||
SFT_DATASET: /data/scratch/github_actions/chat/data.json
|
SFT_DATASET: /data/scratch/github_actions/chat/data.json
|
||||||
PROMPT_PATH: /data/scratch/github_actions/chat/prompts_en.jsonl
|
PROMPT_DATASET: /data/scratch/github_actions/chat/prompts_en.jsonl
|
||||||
PRETRAIN_DATASET: /data/scratch/github_actions/chat/alpaca_data.json
|
PRETRAIN_DATASET: /data/scratch/github_actions/chat/alpaca_data.json
|
||||||
|
|
|
@ -138,6 +138,7 @@ def main(args):
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
|
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
|
||||||
|
|
||||||
|
@ -154,6 +155,7 @@ def main(args):
|
||||||
initial_model,
|
initial_model,
|
||||||
actor_optim,
|
actor_optim,
|
||||||
critic_optim,
|
critic_optim,
|
||||||
|
tokenizer=tokenizer,
|
||||||
ptx_coef=0,
|
ptx_coef=0,
|
||||||
train_batch_size=args.train_batch_size,
|
train_batch_size=args.train_batch_size,
|
||||||
offload_inference_models=args.offload_inference_models,
|
offload_inference_models=args.offload_inference_models,
|
||||||
|
@ -162,8 +164,6 @@ def main(args):
|
||||||
temperature=1.0,
|
temperature=1.0,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
pad_token_id=tokenizer.pad_token_id,
|
|
||||||
eos_token_id=tokenizer.eos_token_id,
|
|
||||||
callbacks=[performance_evaluator],
|
callbacks=[performance_evaluator],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
from typing import Dict, Sequence, Tuple
|
from typing import Dict, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
|
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
|
||||||
|
@ -57,6 +57,7 @@ def _preprocess(
|
||||||
sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
|
sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently"
|
||||||
labels = copy.deepcopy(sequences_token["input_ids"])
|
labels = copy.deepcopy(sequences_token["input_ids"])
|
||||||
for i in range(labels.shape[0]):
|
for i in range(labels.shape[0]):
|
||||||
source_len = sources_token["attention_mask"][i].sum().item()
|
source_len = sources_token["attention_mask"][i].sum().item()
|
||||||
|
@ -64,9 +65,10 @@ def _preprocess(
|
||||||
if tokenizer.padding_side == "right":
|
if tokenizer.padding_side == "right":
|
||||||
# |prompt|completion|eos|pad|
|
# |prompt|completion|eos|pad|
|
||||||
labels[i][:source_len] = IGNORE_INDEX
|
labels[i][:source_len] = IGNORE_INDEX
|
||||||
|
labels[i][-pad_len:] = IGNORE_INDEX
|
||||||
elif tokenizer.padding_side == "left":
|
elif tokenizer.padding_side == "left":
|
||||||
# |pad|prompt|completion|eos|
|
# |pad|prompt|completion|eos|
|
||||||
labels[i][pad_len : pad_len + source_len] = IGNORE_INDEX
|
labels[i][: pad_len + source_len] = IGNORE_INDEX
|
||||||
else:
|
else:
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
|
|
||||||
|
@ -126,6 +128,8 @@ class SFTDataset(Dataset):
|
||||||
|
|
||||||
sources = [data["prompt"] for data in dataset]
|
sources = [data["prompt"] for data in dataset]
|
||||||
targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())]
|
targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())]
|
||||||
|
|
||||||
|
logger.info("Tokenizing inputs... This may take some time...")
|
||||||
if isinstance(tokenizer, ChatGLMTokenizer):
|
if isinstance(tokenizer, ChatGLMTokenizer):
|
||||||
self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
|
self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
|
||||||
sources, targets, tokenizer, max_length
|
sources, targets, tokenizer, max_length
|
||||||
|
@ -133,6 +137,8 @@ class SFTDataset(Dataset):
|
||||||
else:
|
else:
|
||||||
self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
|
self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
|
||||||
|
|
||||||
|
logger.info("Loaded dataset.")
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
length = self.input_ids.shape[0]
|
length = self.input_ids.shape[0]
|
||||||
return length
|
return length
|
||||||
|
@ -148,7 +154,11 @@ class SupervisedDataset(Dataset):
|
||||||
"""Dataset for supervised fine-tuning."""
|
"""Dataset for supervised fine-tuning."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, data_path: str, tokenizer: PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512
|
self,
|
||||||
|
data_path: str,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
max_datasets_size: Optional[int] = None,
|
||||||
|
max_length: int = 512,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
logger.info("Loading data...")
|
logger.info("Loading data...")
|
||||||
|
@ -175,6 +185,8 @@ class SupervisedDataset(Dataset):
|
||||||
else:
|
else:
|
||||||
self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
|
self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
|
||||||
|
|
||||||
|
logger.info("Loaded dataset.")
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
length = self.input_ids.shape[0]
|
length = self.input_ids.shape[0]
|
||||||
return length
|
return length
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import random
|
import random
|
||||||
|
import warnings
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -30,9 +31,11 @@ class NaiveExperienceBuffer(ExperienceBuffer):
|
||||||
experience.to_device(torch.device("cpu"))
|
experience.to_device(torch.device("cpu"))
|
||||||
items = split_experience_batch(experience)
|
items = split_experience_batch(experience)
|
||||||
self.items.extend(items)
|
self.items.extend(items)
|
||||||
|
|
||||||
if self.limit > 0:
|
if self.limit > 0:
|
||||||
samples_to_remove = len(self.items) - self.limit
|
samples_to_remove = len(self.items) - self.limit
|
||||||
if samples_to_remove > 0:
|
if samples_to_remove > 0:
|
||||||
|
warnings.warn(f"Experience buffer is full. Removing {samples_to_remove} samples.")
|
||||||
self.items = self.items[samples_to_remove:]
|
self.items = self.items[samples_to_remove:]
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
|
|
|
@ -3,8 +3,7 @@ from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
from coati.models.base import Actor, Critic, RewardModel
|
||||||
from coati.models.base import Actor
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -59,16 +58,13 @@ class Experience:
|
||||||
|
|
||||||
|
|
||||||
class ExperienceMaker(ABC):
|
class ExperienceMaker(ABC):
|
||||||
def __init__(
|
def __init__(self, actor: Actor, critic: Critic, reward_model: RewardModel, initial_model: Actor) -> None:
|
||||||
self, actor: Actor, critic: nn.Module, reward_model: nn.Module, initial_model: Actor, kl_coef: float = 0.1
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.actor = actor
|
self.actor = actor
|
||||||
self.critic = critic
|
self.critic = critic
|
||||||
self.reward_model = reward_model
|
self.reward_model = reward_model
|
||||||
self.initial_model = initial_model
|
self.initial_model = initial_model
|
||||||
self.kl_coef = kl_coef
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
|
def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from coati.models.base import Actor, Critic, RewardModel
|
||||||
from coati.models.generation import generate
|
from coati.models.generation import generate
|
||||||
from coati.models.utils import calc_action_log_probs, compute_reward
|
from coati.models.utils import calc_action_log_probs, compute_reward
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
from .base import Experience, ExperienceMaker
|
from .base import Experience, ExperienceMaker
|
||||||
|
|
||||||
|
@ -11,6 +13,19 @@ class NaiveExperienceMaker(ExperienceMaker):
|
||||||
Naive experience maker.
|
Naive experience maker.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
actor: Actor,
|
||||||
|
critic: Critic,
|
||||||
|
reward_model: RewardModel,
|
||||||
|
initial_model: Actor,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
kl_coef: float = 0.1,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(actor, critic, reward_model, initial_model)
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.kl_coef = kl_coef
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
|
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
|
||||||
self.actor.eval()
|
self.actor.eval()
|
||||||
|
@ -19,16 +34,16 @@ class NaiveExperienceMaker(ExperienceMaker):
|
||||||
self.reward_model.eval()
|
self.reward_model.eval()
|
||||||
|
|
||||||
# generate sequences
|
# generate sequences
|
||||||
sequences = generate(self.actor, input_ids, **generate_kwargs)
|
sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs)
|
||||||
|
|
||||||
# calculate auxiliary tensors
|
# calculate auxiliary tensors
|
||||||
attention_mask = None
|
attention_mask = None
|
||||||
pad_token_id = generate_kwargs.get("pad_token_id", None)
|
pad_token_id = self.tokenizer.pad_token_id
|
||||||
if pad_token_id is not None:
|
if pad_token_id is not None:
|
||||||
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
|
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
|
||||||
|
|
||||||
input_len = input_ids.size(1)
|
input_len = input_ids.size(1)
|
||||||
eos_token_id = generate_kwargs.get("eos_token_id", None)
|
eos_token_id = self.tokenizer.eos_token_id
|
||||||
if eos_token_id is None:
|
if eos_token_id is None:
|
||||||
action_mask = torch.ones_like(sequences, dtype=torch.bool)
|
action_mask = torch.ones_like(sequences, dtype=torch.bool)
|
||||||
else:
|
else:
|
||||||
|
@ -40,11 +55,11 @@ class NaiveExperienceMaker(ExperienceMaker):
|
||||||
action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
|
action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
|
||||||
num_actions = action_mask.size(1)
|
num_actions = action_mask.size(1)
|
||||||
|
|
||||||
actor_output = self.actor(sequences, attention_mask)
|
actor_output = self.actor(sequences, attention_mask)["logits"]
|
||||||
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
|
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
|
||||||
base_model_output = self.initial_model(sequences, attention_mask)
|
base_model_output = self.initial_model(sequences, attention_mask)["logits"]
|
||||||
base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
|
base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
|
||||||
value = self.critic(sequences, action_mask, attention_mask)
|
value = self.critic(sequences, attention_mask)
|
||||||
r = self.reward_model(sequences, attention_mask)
|
r = self.reward_model(sequences, attention_mask)
|
||||||
reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
|
reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ class Actor(LoRAModule):
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
**model_kwargs, # HACK: `generate` method may pass more kwargs
|
**model_kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Returns model output."""
|
"""Returns model output."""
|
||||||
output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
|
output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
|
||||||
|
|
|
@ -1,10 +1,7 @@
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from ..lora import LoRAModule
|
from ..lora import LoRAModule
|
||||||
from ..utils import masked_mean
|
|
||||||
|
|
||||||
|
|
||||||
class Critic(LoRAModule):
|
class Critic(LoRAModule):
|
||||||
|
@ -19,37 +16,19 @@ class Critic(LoRAModule):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, model: nn.Module, value_head: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none"
|
||||||
model: nn.Module,
|
|
||||||
value_head: nn.Module,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = "none",
|
|
||||||
use_action_mask: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
|
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
|
||||||
self.model = model
|
self.model = model
|
||||||
self.value_head = value_head
|
self.value_head = value_head
|
||||||
self.use_action_mask = use_action_mask
|
|
||||||
self.convert_to_lora()
|
self.convert_to_lora()
|
||||||
|
|
||||||
def forward(
|
def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||||
self,
|
|
||||||
sequences: torch.LongTensor,
|
|
||||||
action_mask: Optional[torch.Tensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
outputs = self.model(sequences, attention_mask=attention_mask)
|
outputs = self.model(sequences, attention_mask=attention_mask)
|
||||||
last_hidden_states = outputs["last_hidden_state"]
|
last_hidden_states = outputs["last_hidden_state"]
|
||||||
|
sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[
|
||||||
values = self.value_head(last_hidden_states).squeeze(-1)
|
0
|
||||||
|
]
|
||||||
if action_mask is not None and self.use_action_mask:
|
sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths]
|
||||||
num_actions = action_mask.size(1)
|
values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, )
|
||||||
prompt_mask = attention_mask[:, :-num_actions]
|
return values
|
||||||
values = values[:, :-num_actions]
|
|
||||||
value = masked_mean(values, prompt_mask, dim=1)
|
|
||||||
return value
|
|
||||||
|
|
||||||
values = values[:, :-1]
|
|
||||||
value = values.mean(dim=1)
|
|
||||||
return value
|
|
||||||
|
|
|
@ -35,9 +35,12 @@ class RewardModel(LoRAModule):
|
||||||
else:
|
else:
|
||||||
self.value_head = nn.Linear(model.config.n_embd, 1)
|
self.value_head = nn.Linear(model.config.n_embd, 1)
|
||||||
|
|
||||||
def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||||
outputs = self.model(sequences, attention_mask=attention_mask)
|
outputs = self.model(sequences, attention_mask=attention_mask)
|
||||||
last_hidden_states = outputs["last_hidden_state"]
|
last_hidden_states = outputs["last_hidden_state"]
|
||||||
values = self.value_head(last_hidden_states)[:, :-1]
|
sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[
|
||||||
value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
|
0
|
||||||
return value
|
]
|
||||||
|
sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths]
|
||||||
|
values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, )
|
||||||
|
return values
|
||||||
|
|
|
@ -2,6 +2,7 @@ from typing import Any, Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
from .base import Actor
|
from .base import Actor
|
||||||
|
|
||||||
|
@ -63,8 +64,8 @@ def _sample(
|
||||||
)
|
)
|
||||||
outputs = model(**model_inputs)
|
outputs = model(**model_inputs)
|
||||||
|
|
||||||
|
# NOTE: this is correct only in left padding mode
|
||||||
next_token_logits = outputs["logits"][:, -1, :]
|
next_token_logits = outputs["logits"][:, -1, :]
|
||||||
# pre-process distribution
|
|
||||||
next_token_logits = logits_processor(input_ids, next_token_logits)
|
next_token_logits = logits_processor(input_ids, next_token_logits)
|
||||||
# sample
|
# sample
|
||||||
probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
|
probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
|
||||||
|
@ -72,8 +73,7 @@ def _sample(
|
||||||
|
|
||||||
# finished sentences should have their next token be a padding token
|
# finished sentences should have their next token be a padding token
|
||||||
if eos_token_id is not None:
|
if eos_token_id is not None:
|
||||||
if pad_token_id is None:
|
assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
|
||||||
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
|
||||||
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
||||||
|
|
||||||
# update generated ids, model inputs for next step
|
# update generated ids, model inputs for next step
|
||||||
|
@ -96,12 +96,11 @@ def _sample(
|
||||||
def generate(
|
def generate(
|
||||||
model: Actor,
|
model: Actor,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
max_length: int,
|
max_length: int,
|
||||||
num_beams: int = 1,
|
num_beams: int = 1,
|
||||||
do_sample: bool = True,
|
do_sample: bool = True,
|
||||||
early_stopping: bool = False,
|
early_stopping: bool = False,
|
||||||
eos_token_id: Optional[int] = None,
|
|
||||||
pad_token_id: Optional[int] = None,
|
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
|
@ -118,14 +117,13 @@ def generate(
|
||||||
num_beams (int, optional): number of beams. Defaults to 1.
|
num_beams (int, optional): number of beams. Defaults to 1.
|
||||||
do_sample (bool, optional): whether to do sample. Defaults to True.
|
do_sample (bool, optional): whether to do sample. Defaults to True.
|
||||||
early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
|
early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
|
||||||
eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
|
|
||||||
pad_token_id (Optional[int], optional): pad token id. Defaults to None.
|
|
||||||
top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
|
top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
|
||||||
top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
|
top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
|
||||||
temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None.
|
temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None.
|
||||||
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
|
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
|
||||||
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
|
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
assert tokenizer.padding_side == "left", "Current generation only supports left padding."
|
||||||
is_greedy_gen_mode = (num_beams == 1) and do_sample is False
|
is_greedy_gen_mode = (num_beams == 1) and do_sample is False
|
||||||
is_sample_gen_mode = (num_beams == 1) and do_sample is True
|
is_sample_gen_mode = (num_beams == 1) and do_sample is True
|
||||||
is_beam_gen_mode = (num_beams > 1) and do_sample is False
|
is_beam_gen_mode = (num_beams > 1) and do_sample is False
|
||||||
|
@ -139,8 +137,8 @@ def generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
max_length,
|
max_length,
|
||||||
early_stopping=early_stopping,
|
early_stopping=early_stopping,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
|
|
|
@ -13,6 +13,7 @@ class GPTLMLoss(nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
# NOTE: default ignore_index is -100, which is equal to IGNORE_INDEX in sft_dataset.py
|
||||||
self.loss = nn.CrossEntropyLoss()
|
self.loss = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
|
@ -46,18 +46,17 @@ def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.
|
||||||
return log_probs_labels.squeeze(-1)
|
return log_probs_labels.squeeze(-1)
|
||||||
|
|
||||||
|
|
||||||
def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
|
def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
|
||||||
"""Calculate action log probs.
|
"""Calculate action log probs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
output (torch.Tensor): Output tensor of Actor.forward.
|
output (torch.Tensor): Output tensor of Actor.forward.logits.
|
||||||
sequences (torch.LongTensor): Input sequences.
|
sequences (torch.LongTensor): Input sequences.
|
||||||
num_actions (int): Number of actions.
|
num_actions (int): Number of actions.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Action log probs.
|
torch.Tensor: Action log probs.
|
||||||
"""
|
"""
|
||||||
logits = output["logits"]
|
|
||||||
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
|
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
|
||||||
return log_probs[:, -num_actions:]
|
return log_probs[:, -num_actions:]
|
||||||
|
|
||||||
|
|
|
@ -41,13 +41,13 @@ def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_ra
|
||||||
|
|
||||||
def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
|
def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
|
||||||
if model == "gpt2":
|
if model == "gpt2":
|
||||||
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
|
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
|
||||||
elif model == "bloom":
|
elif model == "bloom":
|
||||||
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
|
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
|
||||||
elif model == "opt":
|
elif model == "opt":
|
||||||
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
|
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
|
||||||
elif model == "llama":
|
elif model == "llama":
|
||||||
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
|
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported reward model "{model}"')
|
raise ValueError(f'Unsupported reward model "{model}"')
|
||||||
return critic
|
return critic
|
||||||
|
|
|
@ -7,11 +7,10 @@ import tqdm
|
||||||
from coati.experience_buffer import NaiveExperienceBuffer
|
from coati.experience_buffer import NaiveExperienceBuffer
|
||||||
from coati.experience_maker import Experience
|
from coati.experience_maker import Experience
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from .callbacks import Callback
|
from .callbacks import Callback
|
||||||
from .strategies import Strategy
|
from .strategies import Strategy
|
||||||
from .utils import CycledDataLoader, is_rank_0
|
from .utils import is_rank_0
|
||||||
|
|
||||||
|
|
||||||
class SLTrainer(ABC):
|
class SLTrainer(ABC):
|
||||||
|
@ -47,11 +46,11 @@ class SLTrainer(ABC):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def _before_fit(self):
|
def _before_fit(self):
|
||||||
self.no_epoch_bar = False
|
raise NotImplementedError()
|
||||||
|
|
||||||
def fit(self, *args, **kwargs):
|
def fit(self, *args, **kwargs):
|
||||||
self._before_fit(*args, **kwargs)
|
self._before_fit(*args, **kwargs)
|
||||||
for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0() or self.no_epoch_bar):
|
for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0()):
|
||||||
self._train(epoch)
|
self._train(epoch)
|
||||||
self._eval(epoch)
|
self._eval(epoch)
|
||||||
|
|
||||||
|
@ -123,9 +122,9 @@ class OnPolicyTrainer(ABC):
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
callback.on_learn_batch_start()
|
callback.on_learn_batch_start()
|
||||||
|
|
||||||
def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
|
def _on_learn_batch_end(self, experience: Experience) -> None:
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
callback.on_learn_batch_end(metrics, experience)
|
callback.on_learn_batch_end(experience)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _make_experience(self, collect_step: int):
|
def _make_experience(self, collect_step: int):
|
||||||
|
@ -153,27 +152,26 @@ class OnPolicyTrainer(ABC):
|
||||||
self._learn(update_step)
|
self._learn(update_step)
|
||||||
self._on_learn_epoch_end(update_step)
|
self._on_learn_epoch_end(update_step)
|
||||||
|
|
||||||
|
def _before_fit(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def fit(
|
def fit(
|
||||||
self,
|
self,
|
||||||
prompt_dataloader: DataLoader,
|
|
||||||
pretrain_dataloader: DataLoader,
|
|
||||||
num_episodes: int,
|
num_episodes: int,
|
||||||
num_collect_steps: int,
|
num_collect_steps: int,
|
||||||
num_update_steps: int,
|
num_update_steps: int,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
The main training loop of on-policy rl trainers.
|
The main training loop of on-policy rl trainers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt_dataloader (DataLoader): the dataloader to use for prompt data
|
|
||||||
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
|
|
||||||
num_episodes (int): the number of episodes to train
|
num_episodes (int): the number of episodes to train
|
||||||
num_collect_steps (int): the number of collect steps per episode
|
num_collect_steps (int): the number of collect steps per episode
|
||||||
num_update_steps (int): the number of update steps per episode
|
num_update_steps (int): the number of update steps per episode
|
||||||
"""
|
"""
|
||||||
self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
|
self._before_fit(*args, **kwargs)
|
||||||
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
|
|
||||||
|
|
||||||
with self._fit_ctx():
|
with self._fit_ctx():
|
||||||
for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
|
for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
|
||||||
with self._episode_ctx(episode):
|
with self._episode_ctx(episode):
|
||||||
|
|
|
@ -35,5 +35,5 @@ class Callback(ABC):
|
||||||
def on_learn_batch_start(self) -> None:
|
def on_learn_batch_start(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
|
def on_learn_batch_end(self, experience: Experience) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -137,7 +137,7 @@ class PerformanceEvaluator(Callback):
|
||||||
return
|
return
|
||||||
self.learn_timer.start()
|
self.learn_timer.start()
|
||||||
|
|
||||||
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
|
def on_learn_batch_end(self, experience: Experience) -> None:
|
||||||
if self.disable:
|
if self.disable:
|
||||||
return
|
return
|
||||||
self.learn_timer.end()
|
self.learn_timer.end()
|
||||||
|
|
|
@ -1,27 +1,26 @@
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from coati.experience_buffer import NaiveExperienceBuffer
|
from coati.experience_buffer import NaiveExperienceBuffer
|
||||||
from coati.experience_maker import Experience, NaiveExperienceMaker
|
from coati.experience_maker import Experience, NaiveExperienceMaker
|
||||||
from coati.models.base import Actor, Critic, get_base_model
|
from coati.models.base import Actor, Critic, RewardModel, get_base_model
|
||||||
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
|
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
|
||||||
from coati.models.utils import calc_action_log_probs
|
from coati.models.utils import calc_action_log_probs
|
||||||
from torch import Tensor
|
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.utils.data import DistributedSampler
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
from .base import OnPolicyTrainer
|
from .base import OnPolicyTrainer
|
||||||
from .callbacks import Callback
|
from .callbacks import Callback
|
||||||
from .strategies import GeminiStrategy, Strategy
|
from .strategies import GeminiStrategy, Strategy
|
||||||
from .utils import is_rank_0, to_device
|
from .utils import CycledDataLoader, is_rank_0, to_device
|
||||||
|
|
||||||
|
|
||||||
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
|
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
|
||||||
unwrapper_model = strategy.unwrap_model(actor)
|
unwrapped_model = strategy.unwrap_model(actor)
|
||||||
hf_model = get_base_model(unwrapper_model)
|
hf_model = get_base_model(unwrapped_model)
|
||||||
new_kwargs = {**generate_kwargs}
|
new_kwargs = {**generate_kwargs}
|
||||||
# use huggingface models method directly
|
# use huggingface models method directly
|
||||||
if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):
|
if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):
|
||||||
|
@ -41,7 +40,7 @@ class PPOTrainer(OnPolicyTrainer):
|
||||||
strategy (Strategy): the strategy to use for training
|
strategy (Strategy): the strategy to use for training
|
||||||
actor (Actor): the actor model in ppo algorithm
|
actor (Actor): the actor model in ppo algorithm
|
||||||
critic (Critic): the critic model in ppo algorithm
|
critic (Critic): the critic model in ppo algorithm
|
||||||
reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences
|
reward_model (RewardModel): the reward model in rlhf algorithm to make reward of sentences
|
||||||
initial_model (Actor): the initial model in rlhf algorithm to generate reference logics to limit the update of actor
|
initial_model (Actor): the initial model in rlhf algorithm to generate reference logics to limit the update of actor
|
||||||
actor_optim (Optimizer): the optimizer to use for actor model
|
actor_optim (Optimizer): the optimizer to use for actor model
|
||||||
critic_optim (Optimizer): the optimizer to use for critic model
|
critic_optim (Optimizer): the optimizer to use for critic model
|
||||||
|
@ -65,10 +64,11 @@ class PPOTrainer(OnPolicyTrainer):
|
||||||
strategy: Strategy,
|
strategy: Strategy,
|
||||||
actor: Actor,
|
actor: Actor,
|
||||||
critic: Critic,
|
critic: Critic,
|
||||||
reward_model: nn.Module,
|
reward_model: RewardModel,
|
||||||
initial_model: Actor,
|
initial_model: Actor,
|
||||||
actor_optim: Optimizer,
|
actor_optim: Optimizer,
|
||||||
critic_optim: Optimizer,
|
critic_optim: Optimizer,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
kl_coef: float = 0.1,
|
kl_coef: float = 0.1,
|
||||||
ptx_coef: float = 0.9,
|
ptx_coef: float = 0.9,
|
||||||
train_batch_size: int = 8,
|
train_batch_size: int = 8,
|
||||||
|
@ -90,11 +90,11 @@ class PPOTrainer(OnPolicyTrainer):
|
||||||
super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)
|
super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)
|
||||||
|
|
||||||
self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
|
self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
|
||||||
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
|
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer, kl_coef)
|
||||||
self.offload_inference_models = offload_inference_models
|
|
||||||
|
|
||||||
self.actor = actor
|
self.actor = actor
|
||||||
self.critic = critic
|
self.critic = critic
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
self.actor_loss_fn = PolicyLoss(eps_clip)
|
self.actor_loss_fn = PolicyLoss(eps_clip)
|
||||||
self.critic_loss_fn = ValueLoss(value_clip)
|
self.critic_loss_fn = ValueLoss(value_clip)
|
||||||
|
@ -104,58 +104,81 @@ class PPOTrainer(OnPolicyTrainer):
|
||||||
self.actor_optim = actor_optim
|
self.actor_optim = actor_optim
|
||||||
self.critic_optim = critic_optim
|
self.critic_optim = critic_optim
|
||||||
|
|
||||||
|
self.offload_inference_models = offload_inference_models
|
||||||
self.device = get_current_device()
|
self.device = get_current_device()
|
||||||
|
|
||||||
|
def _before_fit(
|
||||||
|
self,
|
||||||
|
prompt_dataloader: DataLoader,
|
||||||
|
pretrain_dataloader: DataLoader,
|
||||||
|
log_dir: Optional[str] = None,
|
||||||
|
use_wandb: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
prompt_dataloader (DataLoader): the dataloader to use for prompt data
|
||||||
|
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
|
||||||
|
"""
|
||||||
|
self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
|
||||||
|
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
|
||||||
|
|
||||||
|
self.writer = None
|
||||||
|
if use_wandb and is_rank_0():
|
||||||
|
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
wandb.init(project="Coati-ppo", sync_tensorboard=True)
|
||||||
|
if log_dir is not None and is_rank_0():
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
log_dir = os.path.join(log_dir, "ppo")
|
||||||
|
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
|
||||||
|
self.writer = SummaryWriter(log_dir=log_dir)
|
||||||
|
|
||||||
def _make_experience(self, collect_step: int) -> Experience:
|
def _make_experience(self, collect_step: int) -> Experience:
|
||||||
prompts = self.prompt_dataloader.next()
|
prompts = self.prompt_dataloader.next()
|
||||||
if self.offload_inference_models:
|
if self.offload_inference_models:
|
||||||
# TODO(ver217): this may be controlled by strategy if they are prepared by strategy
|
# TODO(ver217): this may be controlled by strategy if they are prepared by strategy
|
||||||
self.experience_maker.initial_model.to(self.device)
|
self.experience_maker.initial_model.to(self.device)
|
||||||
self.experience_maker.reward_model.to(self.device)
|
self.experience_maker.reward_model.to(self.device)
|
||||||
if isinstance(prompts, Tensor):
|
assert isinstance(prompts, dict), f'Unsupported input type "{type(prompts)}"'
|
||||||
return self.experience_maker.make_experience(prompts, **self.generate_kwargs)
|
return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
|
||||||
elif isinstance(prompts, dict):
|
|
||||||
return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported input type "{type(prompts)}"')
|
|
||||||
|
|
||||||
def _training_step(self, experience: Experience) -> Dict[str, float]:
|
def _training_step(self, experience: Experience):
|
||||||
self.actor.train()
|
self.actor.train()
|
||||||
self.critic.train()
|
self.critic.train()
|
||||||
# policy loss
|
# policy loss
|
||||||
num_actions = experience.action_mask.size(1)
|
num_actions = experience.action_log_probs.size(1)
|
||||||
actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
|
actor_logits = self.actor(experience.sequences, experience.attention_mask)["logits"]
|
||||||
action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
|
action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
|
||||||
actor_loss = self.actor_loss_fn(
|
actor_loss = self.actor_loss_fn(
|
||||||
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
|
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
|
||||||
)
|
)
|
||||||
|
actor_loss = (1 - self.ptx_coef) * actor_loss
|
||||||
|
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
|
||||||
|
|
||||||
# ptx loss
|
# ptx loss
|
||||||
if self.ptx_coef != 0:
|
if self.ptx_coef != 0:
|
||||||
batch = self.pretrain_dataloader.next()
|
batch = self.pretrain_dataloader.next()
|
||||||
batch = to_device(batch, self.device)
|
batch = to_device(batch, self.device)
|
||||||
ptx_log_probs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"])["logits"]
|
ptx_log_probs = self.actor(batch["input_ids"], batch["attention_mask"])["logits"]
|
||||||
ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch["labels"])
|
ptx_loss = self.ptx_coef * self.ptx_loss_fn(ptx_log_probs, batch["labels"])
|
||||||
actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
|
self.strategy.backward(ptx_loss, self.actor, self.actor_optim)
|
||||||
|
|
||||||
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
|
|
||||||
self.strategy.optimizer_step(self.actor_optim)
|
self.strategy.optimizer_step(self.actor_optim)
|
||||||
self.actor_optim.zero_grad()
|
self.actor_optim.zero_grad()
|
||||||
|
|
||||||
# value loss
|
# value loss
|
||||||
values = self.critic(
|
values = self.critic(experience.sequences, attention_mask=experience.attention_mask)
|
||||||
experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
|
critic_loss = self.critic_loss_fn(values, experience.values, experience.reward)
|
||||||
)
|
|
||||||
critic_loss = self.critic_loss_fn(
|
|
||||||
values, experience.values, experience.reward, action_mask=experience.action_mask
|
|
||||||
)
|
|
||||||
critic_loss = critic_loss * self.vf_coef
|
critic_loss = critic_loss * self.vf_coef
|
||||||
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
|
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
|
||||||
self.strategy.optimizer_step(self.critic_optim)
|
self.strategy.optimizer_step(self.critic_optim)
|
||||||
self.critic_optim.zero_grad()
|
self.critic_optim.zero_grad()
|
||||||
|
|
||||||
return {"reward": experience.reward.mean().item()}
|
|
||||||
|
|
||||||
def _learn(self, update_step: int):
|
def _learn(self, update_step: int):
|
||||||
if self.offload_inference_models:
|
if self.offload_inference_models:
|
||||||
self.experience_maker.initial_model.to("cpu")
|
self.experience_maker.initial_model.to("cpu")
|
||||||
|
@ -166,8 +189,8 @@ class PPOTrainer(OnPolicyTrainer):
|
||||||
experience = self.data_buffer.sample()
|
experience = self.data_buffer.sample()
|
||||||
self._on_learn_batch_start()
|
self._on_learn_batch_start()
|
||||||
experience.to_device(self.device)
|
experience.to_device(self.device)
|
||||||
metrics = self._training_step(experience)
|
self._training_step(experience)
|
||||||
self._on_learn_batch_end(metrics, experience)
|
self._on_learn_batch_end(experience)
|
||||||
else:
|
else:
|
||||||
if isinstance(self.dataloader.sampler, DistributedSampler):
|
if isinstance(self.dataloader.sampler, DistributedSampler):
|
||||||
self.dataloader.sampler.set_epoch(update_step)
|
self.dataloader.sampler.set_epoch(update_step)
|
||||||
|
@ -175,6 +198,5 @@ class PPOTrainer(OnPolicyTrainer):
|
||||||
for experience in pbar:
|
for experience in pbar:
|
||||||
self._on_learn_batch_start()
|
self._on_learn_batch_start()
|
||||||
experience.to_device(self.device)
|
experience.to_device(self.device)
|
||||||
metrics = self._training_step(experience)
|
self._training_step(experience)
|
||||||
self._on_learn_batch_end(metrics, experience)
|
self._on_learn_batch_end(experience)
|
||||||
pbar.set_postfix(metrics)
|
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
from datetime import datetime
|
from typing import Callable, Optional
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
@ -40,10 +38,12 @@ class RewardModelTrainer(SLTrainer):
|
||||||
self.loss_fn = loss_fn
|
self.loss_fn = loss_fn
|
||||||
self.scheduler = lr_scheduler
|
self.scheduler = lr_scheduler
|
||||||
|
|
||||||
|
self.num_train_step = 0
|
||||||
|
|
||||||
def _eval(self, epoch):
|
def _eval(self, epoch):
|
||||||
if self.eval_dataloader is not None:
|
if self.eval_dataloader is not None:
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
dist, on, cnt = 0, 0, 0
|
dist, num_correct, num_samples = 0, 0, 0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
|
for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
|
||||||
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
|
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
|
||||||
|
@ -52,27 +52,21 @@ class RewardModelTrainer(SLTrainer):
|
||||||
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
|
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
|
||||||
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
|
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
|
||||||
reject_reward = self.model(reject_ids, attention_mask=r_mask)
|
reject_reward = self.model(reject_ids, attention_mask=r_mask)
|
||||||
for i in range(len(chosen_reward)):
|
num_samples += chosen_ids.size(0)
|
||||||
cnt += 1
|
num_correct += (chosen_reward > reject_reward).sum().item()
|
||||||
if chosen_reward[i] > reject_reward[i]:
|
|
||||||
on += 1
|
|
||||||
dist += (chosen_reward - reject_reward).mean().item()
|
dist += (chosen_reward - reject_reward).mean().item()
|
||||||
self.dist = dist / len(self.eval_dataloader)
|
self.dist = dist / len(self.eval_dataloader)
|
||||||
self.acc = on / cnt
|
self.acc = num_correct / num_samples
|
||||||
|
|
||||||
if is_rank_0():
|
if self.writer:
|
||||||
log = pd.DataFrame(
|
self.writer.add_scalar("eval/dist", self.dist, epoch)
|
||||||
[[(epoch + 1) * len(self.train_dataloader), self.loss.item(), self.dist, self.acc]],
|
self.writer.add_scalar("eval/acc", self.acc, epoch)
|
||||||
columns=["step", "loss", "dist", "acc"],
|
|
||||||
)
|
|
||||||
log.to_csv("log.csv", mode="a", header=False, index=False)
|
|
||||||
|
|
||||||
def _train(self, epoch):
|
def _train(self, epoch):
|
||||||
self.model.train()
|
self.model.train()
|
||||||
step_bar = tqdm.trange(
|
step_bar = tqdm.trange(
|
||||||
len(self.train_dataloader), desc="Train step of epoch %d" % epoch, disable=not is_rank_0()
|
len(self.train_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0()
|
||||||
)
|
)
|
||||||
cnt = 0
|
|
||||||
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
|
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
|
||||||
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
|
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
|
||||||
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
|
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
|
||||||
|
@ -80,26 +74,50 @@ class RewardModelTrainer(SLTrainer):
|
||||||
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
|
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
|
||||||
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
|
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
|
||||||
reject_reward = self.model(reject_ids, attention_mask=r_mask)
|
reject_reward = self.model(reject_ids, attention_mask=r_mask)
|
||||||
self.loss = self.loss_fn(chosen_reward, reject_reward)
|
loss = self.loss_fn(chosen_reward, reject_reward)
|
||||||
self.strategy.backward(self.loss, self.model, self.optimizer)
|
self.strategy.backward(loss, self.model, self.optimizer)
|
||||||
self.strategy.optimizer_step(self.optimizer)
|
self.strategy.optimizer_step(self.optimizer)
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
cnt += 1
|
if self.writer:
|
||||||
if cnt % 100 == 0:
|
self.writer.add_scalar("train/loss", loss.item(), self.num_train_step)
|
||||||
|
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
|
||||||
|
self.writer.add_scalar("train/dist", (chosen_reward - reject_reward).mean().item(), self.num_train_step)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/acc", (chosen_reward > reject_reward).float().mean().item(), self.num_train_step
|
||||||
|
)
|
||||||
|
self.num_train_step += 1
|
||||||
|
if self.num_train_step % 100 == 0:
|
||||||
self.scheduler.step()
|
self.scheduler.step()
|
||||||
step_bar.update()
|
step_bar.update()
|
||||||
step_bar.close()
|
step_bar.close()
|
||||||
|
|
||||||
def _before_fit(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, eval_dataloader: DataLoader):
|
def _before_fit(
|
||||||
|
self,
|
||||||
|
train_dataloader: DataLoader,
|
||||||
|
eval_dataloader: DataLoader,
|
||||||
|
log_dir: Optional[str] = None,
|
||||||
|
use_wandb: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
train_dataloader (DataLoader): the dataloader to use for training
|
train_dataloader (DataLoader): the dataloader to use for training
|
||||||
valid_dataloader (DataLoader): the dataloader to use for validation
|
|
||||||
eval_dataloader (DataLoader): the dataloader to use for evaluation
|
eval_dataloader (DataLoader): the dataloader to use for evaluation
|
||||||
"""
|
"""
|
||||||
super()._before_fit()
|
|
||||||
self.datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
|
||||||
|
|
||||||
self.train_dataloader = train_dataloader
|
self.train_dataloader = train_dataloader
|
||||||
self.valid_dataloader = valid_dataloader
|
|
||||||
self.eval_dataloader = eval_dataloader
|
self.eval_dataloader = eval_dataloader
|
||||||
|
|
||||||
|
self.writer = None
|
||||||
|
if use_wandb and is_rank_0():
|
||||||
|
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
wandb.init(project="Coati-rm", sync_tensorboard=True)
|
||||||
|
if log_dir is not None and is_rank_0():
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
log_dir = os.path.join(log_dir, "rm")
|
||||||
|
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
|
||||||
|
self.writer = SummaryWriter(log_dir=log_dir)
|
||||||
|
|
|
@ -1,10 +1,8 @@
|
||||||
import time
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import tqdm
|
import tqdm
|
||||||
import wandb
|
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import _LRScheduler
|
from torch.optim.lr_scheduler import _LRScheduler
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
@ -48,38 +46,34 @@ class SFTTrainer(SLTrainer):
|
||||||
self.accumulation_steps = accumulation_steps
|
self.accumulation_steps = accumulation_steps
|
||||||
self.scheduler = lr_scheduler
|
self.scheduler = lr_scheduler
|
||||||
|
|
||||||
|
self.num_train_step = 0
|
||||||
|
self.num_eval_step = 0
|
||||||
|
|
||||||
def _train(self, epoch: int):
|
def _train(self, epoch: int):
|
||||||
self.model.train()
|
self.model.train()
|
||||||
for batch_id, batch in enumerate(self.train_dataloader):
|
step_bar = tqdm.trange(
|
||||||
|
len(self.train_dataloader) // self.accumulation_steps,
|
||||||
|
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||||
|
disable=not is_rank_0(),
|
||||||
|
)
|
||||||
|
for i, batch in enumerate(self.train_dataloader):
|
||||||
batch = to_device(batch, torch.cuda.current_device())
|
batch = to_device(batch, torch.cuda.current_device())
|
||||||
if "attention_mask" in batch:
|
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
|
||||||
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
|
loss = outputs.loss / self.accumulation_steps
|
||||||
else:
|
|
||||||
outputs = self.model(batch["input_ids"], labels=batch["labels"])
|
|
||||||
|
|
||||||
loss = outputs.loss
|
|
||||||
loss = loss / self.accumulation_steps
|
|
||||||
|
|
||||||
self.strategy.backward(loss, self.model, self.optimizer)
|
|
||||||
|
|
||||||
self.total_loss += loss.item()
|
self.total_loss += loss.item()
|
||||||
|
self.strategy.backward(loss, self.model, self.optimizer)
|
||||||
# gradient accumulation
|
# gradient accumulation
|
||||||
if (batch_id + 1) % self.accumulation_steps == 0:
|
if (i + 1) % self.accumulation_steps == 0:
|
||||||
self.strategy.optimizer_step(self.optimizer)
|
self.strategy.optimizer_step(self.optimizer)
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
self.scheduler.step()
|
self.scheduler.step()
|
||||||
if is_rank_0() and self.use_wandb:
|
if self.writer:
|
||||||
wandb.log(
|
self.writer.add_scalar("train/loss", self.total_loss, self.num_train_step)
|
||||||
{
|
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
|
||||||
"loss": self.total_loss / self.accumulation_steps,
|
self.num_train_step += 1
|
||||||
"lr": self.scheduler.get_last_lr()[0],
|
|
||||||
"epoch": epoch,
|
|
||||||
"batch_id": batch_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
self.total_loss = 0
|
self.total_loss = 0
|
||||||
self.step_bar.update()
|
step_bar.update()
|
||||||
|
step_bar.close()
|
||||||
|
|
||||||
def _eval(self, epoch: int):
|
def _eval(self, epoch: int):
|
||||||
if self.eval_dataloader is not None:
|
if self.eval_dataloader is not None:
|
||||||
|
@ -91,20 +85,21 @@ class SFTTrainer(SLTrainer):
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]
|
batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]
|
||||||
)
|
)
|
||||||
loss = outputs.loss
|
loss_sum += outputs.loss.item()
|
||||||
|
|
||||||
loss_sum += loss.item()
|
|
||||||
num_seen += batch["input_ids"].size(0)
|
num_seen += batch["input_ids"].size(0)
|
||||||
|
|
||||||
loss_mean = loss_sum / num_seen
|
loss_mean = loss_sum / num_seen
|
||||||
if dist.get_rank() == 0:
|
if dist.get_rank() == 0:
|
||||||
self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}")
|
self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}")
|
||||||
|
if self.writer:
|
||||||
|
self.writer.add_scalar("eval/loss", loss_mean, self.num_eval_step)
|
||||||
|
self.num_eval_step += 1
|
||||||
|
|
||||||
def _before_fit(
|
def _before_fit(
|
||||||
self,
|
self,
|
||||||
train_dataloader: DataLoader,
|
train_dataloader: DataLoader,
|
||||||
eval_dataloader: Optional[DataLoader] = None,
|
eval_dataloader: Optional[DataLoader] = None,
|
||||||
logger: Optional[DistributedLogger] = None,
|
logger: Optional[DistributedLogger] = None,
|
||||||
|
log_dir: Optional[str] = None,
|
||||||
use_wandb: bool = False,
|
use_wandb: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -116,15 +111,20 @@ class SFTTrainer(SLTrainer):
|
||||||
self.eval_dataloader = eval_dataloader
|
self.eval_dataloader = eval_dataloader
|
||||||
|
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.use_wandb = use_wandb
|
self.writer = None
|
||||||
if use_wandb:
|
if use_wandb and is_rank_0():
|
||||||
wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
|
||||||
wandb.watch(self.model)
|
import wandb
|
||||||
|
|
||||||
|
wandb.init(project="Coati-sft", sync_tensorboard=True)
|
||||||
|
if log_dir is not None and is_rank_0():
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
log_dir = os.path.join(log_dir, "sft")
|
||||||
|
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
|
||||||
|
self.writer = SummaryWriter(log_dir=log_dir)
|
||||||
|
|
||||||
self.total_loss = 0
|
self.total_loss = 0
|
||||||
self.no_epoch_bar = True
|
|
||||||
self.step_bar = tqdm.trange(
|
|
||||||
len(self.train_dataloader) // self.accumulation_steps * self.max_epochs,
|
|
||||||
desc=f"steps",
|
|
||||||
disable=not is_rank_0(),
|
|
||||||
)
|
|
||||||
|
|
|
@ -1,17 +1,13 @@
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
|
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
|
||||||
from colossalai.booster.plugin.gemini_plugin import GeminiModel
|
|
||||||
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
|
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
|
||||||
from colossalai.tensor import ProcessGroup, ShardSpec
|
from colossalai.lazy.lazy_init import LazyInitContext
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from colossalai.zero import ColoInitContext
|
|
||||||
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
|
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
|
||||||
|
|
||||||
from .ddp import DDPStrategy
|
from .ddp import DDPStrategy
|
||||||
|
@ -65,14 +61,11 @@ class LowLevelZeroStrategy(DDPStrategy):
|
||||||
assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"'
|
assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"'
|
||||||
|
|
||||||
plugin_initializer = lambda: LowLevelZeroPlugin(
|
plugin_initializer = lambda: LowLevelZeroPlugin(
|
||||||
# zero_config
|
|
||||||
stage=stage,
|
stage=stage,
|
||||||
precision=precision,
|
precision=precision,
|
||||||
# zero_optim_config
|
|
||||||
reduce_bucket_size_in_m=reduce_bucket_size,
|
reduce_bucket_size_in_m=reduce_bucket_size,
|
||||||
overlap_communication=overlap_communication,
|
overlap_communication=overlap_communication,
|
||||||
cpu_offload=(placement_policy == "cpu"),
|
cpu_offload=(placement_policy == "cpu"),
|
||||||
# optim_config
|
|
||||||
initial_scale=initial_scale,
|
initial_scale=initial_scale,
|
||||||
growth_factor=growth_factor,
|
growth_factor=growth_factor,
|
||||||
backoff_factor=backoff_factor,
|
backoff_factor=backoff_factor,
|
||||||
|
@ -136,7 +129,7 @@ class GeminiStrategy(DDPStrategy):
|
||||||
self,
|
self,
|
||||||
seed: int = 42,
|
seed: int = 42,
|
||||||
shard_init: bool = False, # only for stage 3
|
shard_init: bool = False, # only for stage 3
|
||||||
placement_policy: str = "cuda",
|
placement_policy: str = "auto",
|
||||||
pin_memory: bool = True, # only for stage 3
|
pin_memory: bool = True, # only for stage 3
|
||||||
force_outputs_fp32: bool = False, # only for stage 3
|
force_outputs_fp32: bool = False, # only for stage 3
|
||||||
search_range_m: int = 32, # only for stage 3
|
search_range_m: int = 32, # only for stage 3
|
||||||
|
@ -153,8 +146,6 @@ class GeminiStrategy(DDPStrategy):
|
||||||
max_norm: float = 0.0,
|
max_norm: float = 0.0,
|
||||||
norm_type: float = 2.0,
|
norm_type: float = 2.0,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
|
|
||||||
|
|
||||||
# TODO(ver217): support shard_init when using from_pretrained()
|
# TODO(ver217): support shard_init when using from_pretrained()
|
||||||
if shard_init:
|
if shard_init:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
@ -167,8 +158,7 @@ class GeminiStrategy(DDPStrategy):
|
||||||
|
|
||||||
# NOTE: dist should be initialized before calling get_current_device()
|
# NOTE: dist should be initialized before calling get_current_device()
|
||||||
plugin_initializer = lambda: GeminiPlugin(
|
plugin_initializer = lambda: GeminiPlugin(
|
||||||
# gemini_config
|
chunk_init_device=get_current_device(),
|
||||||
device=get_current_device(),
|
|
||||||
placement_policy=placement_policy,
|
placement_policy=placement_policy,
|
||||||
precision="fp16",
|
precision="fp16",
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
|
@ -177,9 +167,7 @@ class GeminiStrategy(DDPStrategy):
|
||||||
search_range_m=search_range_m,
|
search_range_m=search_range_m,
|
||||||
hidden_dim=hidden_dim,
|
hidden_dim=hidden_dim,
|
||||||
min_chunk_size_m=min_chunk_size_m,
|
min_chunk_size_m=min_chunk_size_m,
|
||||||
# zero_optim_config
|
|
||||||
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
|
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
|
||||||
# optim_config
|
|
||||||
initial_scale=initial_scale,
|
initial_scale=initial_scale,
|
||||||
growth_factor=growth_factor,
|
growth_factor=growth_factor,
|
||||||
backoff_factor=backoff_factor,
|
backoff_factor=backoff_factor,
|
||||||
|
@ -200,15 +188,8 @@ class GeminiStrategy(DDPStrategy):
|
||||||
colossalai.launch_from_torch({}, seed=self.seed)
|
colossalai.launch_from_torch({}, seed=self.seed)
|
||||||
|
|
||||||
def model_init_context(self):
|
def model_init_context(self):
|
||||||
world_size = dist.get_world_size()
|
return LazyInitContext(default_device=get_current_device())
|
||||||
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
|
|
||||||
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
|
|
||||||
return ColoInitContext(
|
|
||||||
device=get_current_device(), dtype=torch.half, default_pg=shard_pg, default_dist_spec=default_dist_spec
|
|
||||||
)
|
|
||||||
|
|
||||||
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
||||||
assert isinstance(model, GeminiModel)
|
assert isinstance(model, GeminiDDP)
|
||||||
ddp_model = model.unwrap()
|
return model.module
|
||||||
assert isinstance(ddp_model, GeminiDDP)
|
|
||||||
return ddp_model.module
|
|
||||||
|
|
|
@ -45,9 +45,17 @@ def eval(args):
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
raise ValueError(f'Unsupported model "{args.model}"')
|
||||||
|
|
||||||
actor.eval()
|
actor.eval()
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device())
|
input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device())
|
||||||
outputs = generate(
|
outputs = generate(
|
||||||
actor, input_ids, max_length=args.max_length, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1
|
actor,
|
||||||
|
input_ids,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_length=args.max_length,
|
||||||
|
do_sample=True,
|
||||||
|
top_k=50,
|
||||||
|
top_p=0.95,
|
||||||
|
num_return_sequences=1,
|
||||||
)
|
)
|
||||||
output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
|
output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
|
||||||
print(f"[Output]: {''.join(output)}")
|
print(f"[Output]: {''.join(output)}")
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
pandas>=1.4.1
|
pandas>=1.4.1
|
||||||
sentencepiece
|
sentencepiece
|
||||||
colossalai==0.3.1
|
colossalai>=0.3.1
|
||||||
|
|
|
@ -23,7 +23,7 @@ def main(args):
|
||||||
if args.strategy == "ddp":
|
if args.strategy == "ddp":
|
||||||
strategy = DDPStrategy()
|
strategy = DDPStrategy()
|
||||||
elif args.strategy == "colossalai_gemini":
|
elif args.strategy == "colossalai_gemini":
|
||||||
strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
|
strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5)
|
||||||
elif args.strategy == "colossalai_zero2":
|
elif args.strategy == "colossalai_zero2":
|
||||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
||||||
else:
|
else:
|
||||||
|
@ -65,8 +65,8 @@ def main(args):
|
||||||
if args.rm_path is not None:
|
if args.rm_path is not None:
|
||||||
reward_model.load_state_dict(state_dict, strict=False)
|
reward_model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
initial_model.to(torch.float16).to(torch.cuda.current_device())
|
initial_model.to(torch.bfloat16).to(torch.cuda.current_device())
|
||||||
reward_model.to(torch.float16).to(torch.cuda.current_device())
|
reward_model.to(torch.bfloat16).to(torch.cuda.current_device())
|
||||||
|
|
||||||
if args.model == "gpt2":
|
if args.model == "gpt2":
|
||||||
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
|
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
|
||||||
|
@ -80,13 +80,13 @@ def main(args):
|
||||||
raise ValueError(f'Unsupported actor model "{args.model}"')
|
raise ValueError(f'Unsupported actor model "{args.model}"')
|
||||||
|
|
||||||
if rm_model_name == "gpt2":
|
if rm_model_name == "gpt2":
|
||||||
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
|
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
|
||||||
elif rm_model_name == "bloom":
|
elif rm_model_name == "bloom":
|
||||||
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
|
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
|
||||||
elif rm_model_name == "opt":
|
elif rm_model_name == "opt":
|
||||||
critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
|
critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
|
||||||
elif rm_model_name == "llama":
|
elif rm_model_name == "llama":
|
||||||
critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
|
critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
|
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
|
||||||
|
|
||||||
|
@ -94,17 +94,16 @@ def main(args):
|
||||||
critic.load_state_dict(state_dict, strict=False)
|
critic.load_state_dict(state_dict, strict=False)
|
||||||
del state_dict
|
del state_dict
|
||||||
|
|
||||||
if args.strategy != "colossalai_gemini":
|
actor.to(torch.bfloat16).to(torch.cuda.current_device())
|
||||||
critic.to(torch.float16).to(torch.cuda.current_device())
|
critic.to(torch.bfloat16).to(torch.cuda.current_device())
|
||||||
actor.to(torch.float16).to(torch.cuda.current_device())
|
|
||||||
|
|
||||||
# configure optimizer
|
# configure optimizer
|
||||||
if args.strategy.startswith("colossalai"):
|
if args.strategy.startswith("colossalai"):
|
||||||
actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
|
actor_optim = HybridAdam(actor.parameters(), lr=args.lr)
|
||||||
critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
|
critic_optim = HybridAdam(critic.parameters(), lr=args.lr)
|
||||||
else:
|
else:
|
||||||
actor_optim = Adam(actor.parameters(), lr=1e-7)
|
actor_optim = Adam(actor.parameters(), lr=args.lr)
|
||||||
critic_optim = Adam(critic.parameters(), lr=1e-7)
|
critic_optim = Adam(critic.parameters(), lr=args.lr)
|
||||||
|
|
||||||
# configure tokenizer
|
# configure tokenizer
|
||||||
if args.model == "gpt2":
|
if args.model == "gpt2":
|
||||||
|
@ -126,8 +125,15 @@ def main(args):
|
||||||
tokenizer.pad_token = tokenizer.unk_token
|
tokenizer.pad_token = tokenizer.unk_token
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
raise ValueError(f'Unsupported model "{args.model}"')
|
||||||
|
# NOTE: generate() requires padding_side to be "left"
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_dataset, max_datasets_size=16384)
|
prompt_dataset = PromptDataset(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
data_path=args.prompt_dataset,
|
||||||
|
max_datasets_size=args.max_datasets_size,
|
||||||
|
max_length=args.max_input_len,
|
||||||
|
)
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
|
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
|
||||||
else:
|
else:
|
||||||
|
@ -137,7 +143,10 @@ def main(args):
|
||||||
)
|
)
|
||||||
|
|
||||||
pretrain_dataset = SupervisedDataset(
|
pretrain_dataset = SupervisedDataset(
|
||||||
tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384, max_length=args.max_input_len
|
tokenizer=tokenizer,
|
||||||
|
data_path=args.pretrain_dataset,
|
||||||
|
max_datasets_size=args.max_datasets_size,
|
||||||
|
max_length=args.max_input_len,
|
||||||
)
|
)
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
|
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
|
||||||
|
@ -161,6 +170,7 @@ def main(args):
|
||||||
initial_model,
|
initial_model,
|
||||||
actor_optim,
|
actor_optim,
|
||||||
critic_optim,
|
critic_optim,
|
||||||
|
tokenizer=tokenizer,
|
||||||
kl_coef=args.kl_coef,
|
kl_coef=args.kl_coef,
|
||||||
ptx_coef=args.ptx_coef,
|
ptx_coef=args.ptx_coef,
|
||||||
train_batch_size=args.train_batch_size,
|
train_batch_size=args.train_batch_size,
|
||||||
|
@ -169,17 +179,17 @@ def main(args):
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
temperature=1.0,
|
temperature=1.0,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
pad_token_id=tokenizer.pad_token_id,
|
|
||||||
eos_token_id=tokenizer.eos_token_id,
|
|
||||||
offload_inference_models=args.strategy != "colossalai_gemini",
|
offload_inference_models=args.strategy != "colossalai_gemini",
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer.fit(
|
trainer.fit(
|
||||||
prompt_dataloader=prompt_dataloader,
|
|
||||||
pretrain_dataloader=pretrain_dataloader,
|
|
||||||
num_episodes=args.num_episodes,
|
num_episodes=args.num_episodes,
|
||||||
num_collect_steps=args.num_collect_steps,
|
num_collect_steps=args.num_collect_steps,
|
||||||
num_update_steps=args.num_update_steps,
|
num_update_steps=args.num_update_steps,
|
||||||
|
prompt_dataloader=prompt_dataloader,
|
||||||
|
pretrain_dataloader=pretrain_dataloader,
|
||||||
|
log_dir=args.log_dir,
|
||||||
|
use_wandb=args.use_wandb,
|
||||||
)
|
)
|
||||||
|
|
||||||
# save model checkpoint after fitting
|
# save model checkpoint after fitting
|
||||||
|
@ -195,6 +205,7 @@ if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset")
|
parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset")
|
||||||
parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
|
parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
|
||||||
|
parser.add_argument("--max_datasets_size", type=int, default=50000)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--strategy",
|
"--strategy",
|
||||||
choices=["ddp", "colossalai_gemini", "colossalai_zero2"],
|
choices=["ddp", "colossalai_gemini", "colossalai_zero2"],
|
||||||
|
@ -216,9 +227,12 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--ptx_batch_size", type=int, default=1)
|
parser.add_argument("--ptx_batch_size", type=int, default=1)
|
||||||
parser.add_argument("--experience_batch_size", type=int, default=8)
|
parser.add_argument("--experience_batch_size", type=int, default=8)
|
||||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||||
|
parser.add_argument("--lr", type=float, default=1e-7)
|
||||||
parser.add_argument("--kl_coef", type=float, default=0.1)
|
parser.add_argument("--kl_coef", type=float, default=0.1)
|
||||||
parser.add_argument("--ptx_coef", type=float, default=0.9)
|
parser.add_argument("--ptx_coef", type=float, default=0.9)
|
||||||
parser.add_argument("--max_input_len", type=int, default=96)
|
parser.add_argument("--max_input_len", type=int, default=96)
|
||||||
parser.add_argument("--max_seq_len", type=int, default=128)
|
parser.add_argument("--max_seq_len", type=int, default=128)
|
||||||
|
parser.add_argument("--log_dir", default="logs", type=str)
|
||||||
|
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import argparse
|
import argparse
|
||||||
from random import randint
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
@ -27,7 +26,7 @@ def train(args):
|
||||||
if args.strategy == "ddp":
|
if args.strategy == "ddp":
|
||||||
strategy = DDPStrategy()
|
strategy = DDPStrategy()
|
||||||
elif args.strategy == "colossalai_gemini":
|
elif args.strategy == "colossalai_gemini":
|
||||||
strategy = GeminiStrategy(placement_policy="cuda")
|
strategy = GeminiStrategy(placement_policy="auto")
|
||||||
elif args.strategy == "colossalai_zero2":
|
elif args.strategy == "colossalai_zero2":
|
||||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
||||||
else:
|
else:
|
||||||
|
@ -46,7 +45,7 @@ def train(args):
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
raise ValueError(f'Unsupported model "{args.model}"')
|
||||||
|
|
||||||
model.to(torch.float16).to(torch.cuda.current_device())
|
model.to(torch.bfloat16).to(torch.cuda.current_device())
|
||||||
|
|
||||||
if args.model_path is not None:
|
if args.model_path is not None:
|
||||||
state_dict = torch.load(args.model_path)
|
state_dict = torch.load(args.model_path)
|
||||||
|
@ -75,9 +74,9 @@ def train(args):
|
||||||
|
|
||||||
# configure optimizer
|
# configure optimizer
|
||||||
if args.strategy.startswith("colossalai"):
|
if args.strategy.startswith("colossalai"):
|
||||||
optim = HybridAdam(model.parameters(), lr=5e-6)
|
optim = HybridAdam(model.parameters(), lr=args.lr)
|
||||||
else:
|
else:
|
||||||
optim = Adam(model.parameters(), lr=5e-6)
|
optim = Adam(model.parameters(), lr=args.lr)
|
||||||
|
|
||||||
# configure loss function
|
# configure loss function
|
||||||
if args.loss_fn == "log_sig":
|
if args.loss_fn == "log_sig":
|
||||||
|
@ -93,21 +92,14 @@ def train(args):
|
||||||
else:
|
else:
|
||||||
data = load_dataset(args.dataset)
|
data = load_dataset(args.dataset)
|
||||||
|
|
||||||
if args.test:
|
train_data = data["train"].select(range(min(args.max_datasets_size, len(data["train"]))))
|
||||||
train_data = data["train"].select(range(20))
|
eval_data = data["test"].select(range(min(args.max_datasets_size, len(data["test"]))))
|
||||||
eval_data = data["test"].select(range(5))
|
|
||||||
else:
|
|
||||||
train_data = data["train"]
|
|
||||||
eval_data = data["test"]
|
|
||||||
valid_data = data["test"].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5)))
|
|
||||||
|
|
||||||
if args.dataset == "Dahoas/rm-static":
|
if args.dataset == "Dahoas/rm-static":
|
||||||
train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len)
|
train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len)
|
||||||
valid_dataset = RmStaticDataset(valid_data, tokenizer, args.max_len)
|
|
||||||
eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len)
|
eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len)
|
||||||
elif args.dataset == "Anthropic/hh-rlhf":
|
elif args.dataset == "Anthropic/hh-rlhf":
|
||||||
train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len)
|
train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len)
|
||||||
valid_dataset = HhRlhfDataset(valid_data, tokenizer, args.max_len)
|
|
||||||
eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len)
|
eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported dataset "{args.dataset}"')
|
raise ValueError(f'Unsupported dataset "{args.dataset}"')
|
||||||
|
@ -121,14 +113,6 @@ def train(args):
|
||||||
rank=dist.get_rank(),
|
rank=dist.get_rank(),
|
||||||
num_replicas=dist.get_world_size(),
|
num_replicas=dist.get_world_size(),
|
||||||
)
|
)
|
||||||
valid_sampler = DistributedSampler(
|
|
||||||
valid_dataset,
|
|
||||||
shuffle=True,
|
|
||||||
seed=42,
|
|
||||||
drop_last=True,
|
|
||||||
rank=dist.get_rank(),
|
|
||||||
num_replicas=dist.get_world_size(),
|
|
||||||
)
|
|
||||||
eval_sampler = DistributedSampler(
|
eval_sampler = DistributedSampler(
|
||||||
eval_dataset,
|
eval_dataset,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
|
@ -139,7 +123,6 @@ def train(args):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
train_sampler = None
|
train_sampler = None
|
||||||
valid_sampler = None
|
|
||||||
eval_sampler = None
|
eval_sampler = None
|
||||||
|
|
||||||
train_dataloader = DataLoader(
|
train_dataloader = DataLoader(
|
||||||
|
@ -150,14 +133,6 @@ def train(args):
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_dataloader = DataLoader(
|
|
||||||
valid_dataset,
|
|
||||||
shuffle=(valid_sampler is None),
|
|
||||||
sampler=valid_sampler,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
pin_memory=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
eval_dataloader = DataLoader(
|
eval_dataloader = DataLoader(
|
||||||
eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True
|
eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True
|
||||||
)
|
)
|
||||||
|
@ -176,7 +151,12 @@ def train(args):
|
||||||
max_epochs=args.max_epochs,
|
max_epochs=args.max_epochs,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader)
|
trainer.fit(
|
||||||
|
train_dataloader=train_dataloader,
|
||||||
|
eval_dataloader=eval_dataloader,
|
||||||
|
log_dir=args.log_dir,
|
||||||
|
use_wandb=args.use_wandb,
|
||||||
|
)
|
||||||
# save model checkpoint after fitting on only rank0
|
# save model checkpoint after fitting on only rank0
|
||||||
strategy.save_model(model, args.save_path, only_rank0=True)
|
strategy.save_model(model, args.save_path, only_rank0=True)
|
||||||
# save optimizer checkpoint on all ranks
|
# save optimizer checkpoint on all ranks
|
||||||
|
@ -200,12 +180,15 @@ if __name__ == "__main__":
|
||||||
"--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static"
|
"--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static"
|
||||||
)
|
)
|
||||||
parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None)
|
parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None)
|
||||||
|
parser.add_argument("--max_datasets_size", type=int, default=1000000)
|
||||||
parser.add_argument("--save_path", type=str, default="rm_ckpt")
|
parser.add_argument("--save_path", type=str, default="rm_ckpt")
|
||||||
parser.add_argument("--max_epochs", type=int, default=1)
|
parser.add_argument("--max_epochs", type=int, default=1)
|
||||||
parser.add_argument("--batch_size", type=int, default=1)
|
parser.add_argument("--batch_size", type=int, default=1)
|
||||||
parser.add_argument("--max_len", type=int, default=512)
|
parser.add_argument("--max_len", type=int, default=512)
|
||||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||||
|
parser.add_argument("--lr", type=float, default=9e-6)
|
||||||
parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
|
parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
|
||||||
parser.add_argument("--test", type=bool, default=False)
|
parser.add_argument("--log_dir", default="logs", type=str)
|
||||||
|
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
train(args)
|
train(args)
|
||||||
|
|
|
@ -16,7 +16,10 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||||
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
||||||
|
|
||||||
torchrun --standalone --nproc_per_node=2 train_reward_model.py \
|
torchrun --standalone --nproc_per_node=2 train_reward_model.py \
|
||||||
--model 'bloom' \
|
--pretrain 'gpt2' \
|
||||||
|
--model 'gpt2' \
|
||||||
--strategy colossalai_zero2 \
|
--strategy colossalai_zero2 \
|
||||||
--loss_fn 'log_sig' \
|
--loss_fn 'log_exp' \
|
||||||
--dataset 'Anthropic/hh-rlhf'
|
--dataset 'Anthropic/hh-rlhf' \
|
||||||
|
--batch_size 16 \
|
||||||
|
--max_epochs 10
|
||||||
|
|
|
@ -23,7 +23,6 @@ from transformers.trainer import get_scheduler
|
||||||
|
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
from colossalai.tensor import ColoParameter
|
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
|
@ -31,7 +30,7 @@ def train(args):
|
||||||
if args.strategy == "ddp":
|
if args.strategy == "ddp":
|
||||||
strategy = DDPStrategy()
|
strategy = DDPStrategy()
|
||||||
elif args.strategy == "colossalai_gemini":
|
elif args.strategy == "colossalai_gemini":
|
||||||
strategy = GeminiStrategy(placement_policy="cuda")
|
strategy = GeminiStrategy(placement_policy="auto")
|
||||||
elif args.strategy == "colossalai_zero2":
|
elif args.strategy == "colossalai_zero2":
|
||||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
||||||
elif args.strategy == "colossalai_zero2_cpu":
|
elif args.strategy == "colossalai_zero2_cpu":
|
||||||
|
@ -57,7 +56,7 @@ def train(args):
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
raise ValueError(f'Unsupported model "{args.model}"')
|
||||||
|
|
||||||
model.to(torch.float16).to(torch.cuda.current_device())
|
model.to(torch.bfloat16).to(torch.cuda.current_device())
|
||||||
|
|
||||||
# configure tokenizer
|
# configure tokenizer
|
||||||
if args.model == "gpt2":
|
if args.model == "gpt2":
|
||||||
|
@ -84,28 +83,21 @@ def train(args):
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
raise ValueError(f'Unsupported model "{args.model}"')
|
||||||
|
|
||||||
if args.model == "llama" and args.strategy == "colossalai_gemini":
|
|
||||||
# this is a hack to deal with the resized embedding
|
|
||||||
# to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
if not isinstance(param, ColoParameter):
|
|
||||||
sub_module_name = ".".join(name.split(".")[:-1])
|
|
||||||
weight_name = name.split(".")[-1]
|
|
||||||
sub_module = model.get_submodule(sub_module_name)
|
|
||||||
setattr(sub_module, weight_name, ColoParameter(param))
|
|
||||||
|
|
||||||
# configure optimizer
|
# configure optimizer
|
||||||
if args.strategy.startswith("colossalai"):
|
if args.strategy.startswith("colossalai"):
|
||||||
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
|
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
|
||||||
else:
|
else:
|
||||||
optim = Adam(model.parameters(), lr=args.lr)
|
optim = Adam(model.parameters(), lr=args.lr)
|
||||||
logger = get_dist_logger()
|
|
||||||
|
|
||||||
# configure dataset
|
# configure dataset
|
||||||
if args.dataset == "yizhongw/self_instruct":
|
if args.dataset == "yizhongw/self_instruct":
|
||||||
train_data = load_dataset(args.dataset, "super_natural_instructions", split="train")
|
train_data = load_dataset(args.dataset, "super_natural_instructions", split="train")
|
||||||
eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test")
|
eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test")
|
||||||
|
|
||||||
|
if args.max_datasets_size is not None:
|
||||||
|
train_data = train_data.select(range(min(args.max_datasets_size, len(train_data))))
|
||||||
|
eval_data = eval_data.select(range(min(args.max_datasets_size, len(eval_data))))
|
||||||
|
|
||||||
train_dataset = SFTDataset(train_data, tokenizer, args.max_len)
|
train_dataset = SFTDataset(train_data, tokenizer, args.max_len)
|
||||||
eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len)
|
eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len)
|
||||||
|
|
||||||
|
@ -176,8 +168,13 @@ def train(args):
|
||||||
accumulation_steps=args.accumulation_steps,
|
accumulation_steps=args.accumulation_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = get_dist_logger()
|
||||||
trainer.fit(
|
trainer.fit(
|
||||||
train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, logger=logger, use_wandb=args.use_wandb
|
train_dataloader=train_dataloader,
|
||||||
|
eval_dataloader=eval_dataloader,
|
||||||
|
logger=logger,
|
||||||
|
log_dir=args.log_dir,
|
||||||
|
use_wandb=args.use_wandb,
|
||||||
)
|
)
|
||||||
|
|
||||||
# save model checkpoint after fitting on only rank0
|
# save model checkpoint after fitting on only rank0
|
||||||
|
@ -207,9 +204,9 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--batch_size", type=int, default=4)
|
parser.add_argument("--batch_size", type=int, default=4)
|
||||||
parser.add_argument("--max_len", type=int, default=512)
|
parser.add_argument("--max_len", type=int, default=512)
|
||||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||||
parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
|
|
||||||
parser.add_argument("--lr", type=float, default=5e-6)
|
parser.add_argument("--lr", type=float, default=5e-6)
|
||||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||||
|
parser.add_argument("--log_dir", default="logs", type=str)
|
||||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -19,7 +19,6 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
|
||||||
--pretrain "/path/to/LLaMa-7B/" \
|
--pretrain "/path/to/LLaMa-7B/" \
|
||||||
--model 'llama' \
|
--model 'llama' \
|
||||||
--strategy colossalai_zero2 \
|
--strategy colossalai_zero2 \
|
||||||
--log_interval 10 \
|
|
||||||
--save_path /path/to/Coati-7B \
|
--save_path /path/to/Coati-7B \
|
||||||
--dataset /path/to/data.json \
|
--dataset /path/to/data.json \
|
||||||
--batch_size 4 \
|
--batch_size 4 \
|
||||||
|
|
|
@ -1,2 +1,2 @@
|
||||||
pytest
|
pytest
|
||||||
colossalai==0.3.1
|
colossalai>=0.3.1
|
||||||
|
|
|
@ -2,7 +2,7 @@ transformers>=4.20.1
|
||||||
tqdm
|
tqdm
|
||||||
datasets
|
datasets
|
||||||
loralib
|
loralib
|
||||||
colossalai==0.3.1
|
colossalai>=0.3.1
|
||||||
torch<2.0.0, >=1.12.1
|
torch<2.0.0, >=1.12.1
|
||||||
langchain
|
langchain
|
||||||
tokenizers
|
tokenizers
|
||||||
|
@ -11,3 +11,4 @@ sse_starlette
|
||||||
wandb
|
wandb
|
||||||
sentencepiece
|
sentencepiece
|
||||||
gpustat
|
gpustat
|
||||||
|
tensorboard
|
||||||
|
|
|
@ -25,8 +25,8 @@ def get_data(batch_size: int, seq_len: int = 10) -> dict:
|
||||||
def train_step(strategy: Strategy, actor: GPTActor, actor_optim: HybridAdam, batch_size: int = 8):
|
def train_step(strategy: Strategy, actor: GPTActor, actor_optim: HybridAdam, batch_size: int = 8):
|
||||||
data = get_data(batch_size)
|
data = get_data(batch_size)
|
||||||
action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool)
|
action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool)
|
||||||
actor_output = actor(data["input_ids"], data["attention_mask"])
|
actor_logits = actor(data["input_ids"], data["attention_mask"])["logits"]
|
||||||
action_log_probs = calc_action_log_probs(actor_output, data["input_ids"], action_mask.size(1))
|
action_log_probs = calc_action_log_probs(actor_logits, data["input_ids"], action_mask.size(1))
|
||||||
loss = action_log_probs.sum()
|
loss = action_log_probs.sum()
|
||||||
strategy.backward(loss, actor, actor_optim)
|
strategy.backward(loss, actor, actor_optim)
|
||||||
strategy.optimizer_step(actor_optim)
|
strategy.optimizer_step(actor_optim)
|
||||||
|
@ -36,7 +36,7 @@ def run_test_checkpoint(strategy_name: str, shard: bool):
|
||||||
if strategy_name == "ddp":
|
if strategy_name == "ddp":
|
||||||
strategy = DDPStrategy()
|
strategy = DDPStrategy()
|
||||||
elif strategy_name == "colossalai_gemini":
|
elif strategy_name == "colossalai_gemini":
|
||||||
strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
|
strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5)
|
||||||
elif strategy_name == "colossalai_zero2":
|
elif strategy_name == "colossalai_zero2":
|
||||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -226,7 +226,9 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size:
|
||||||
check_content(input_ids.masked_select(attention_mask), tokenizer, model)
|
check_content(input_ids.masked_select(attention_mask), tokenizer, model)
|
||||||
assert torch.all(attention_mask)
|
assert torch.all(attention_mask)
|
||||||
ignore_mask = labels == IGNORE_INDEX
|
ignore_mask = labels == IGNORE_INDEX
|
||||||
check_content(input_ids.masked_select(ignore_mask), tokenizer, model)
|
prompt_mask = torch.logical_and(ignore_mask, attention_mask)
|
||||||
|
check_content(input_ids.masked_select(prompt_mask), tokenizer, model)
|
||||||
|
assert torch.all(input_ids.masked_select(ignore_mask ^ prompt_mask) == tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
|
import copy
|
||||||
import os
|
import os
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
@ -8,6 +8,7 @@ from coati.experience_buffer import NaiveExperienceBuffer
|
||||||
from coati.experience_maker import NaiveExperienceMaker
|
from coati.experience_maker import NaiveExperienceMaker
|
||||||
from coati.models.base import RewardModel
|
from coati.models.base import RewardModel
|
||||||
from coati.models.gpt import GPTActor, GPTCritic
|
from coati.models.gpt import GPTActor, GPTCritic
|
||||||
|
from coati.trainer.ppo import _set_default_generate_kwargs
|
||||||
from coati.trainer.strategies import DDPStrategy, GeminiStrategy
|
from coati.trainer.strategies import DDPStrategy, GeminiStrategy
|
||||||
from coati.trainer.strategies.colossalai import LowLevelZeroStrategy
|
from coati.trainer.strategies.colossalai import LowLevelZeroStrategy
|
||||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||||
|
@ -42,27 +43,38 @@ def make_and_consume_experience(strategy):
|
||||||
elif strategy == "colossalai-zero2":
|
elif strategy == "colossalai-zero2":
|
||||||
strategy = LowLevelZeroStrategy()
|
strategy = LowLevelZeroStrategy()
|
||||||
elif strategy == "colossalai-gemini":
|
elif strategy == "colossalai-gemini":
|
||||||
strategy = GeminiStrategy(placement_policy="cuda")
|
strategy = GeminiStrategy(placement_policy="static")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported strategy "{strategy}"')
|
raise ValueError(f'Unsupported strategy "{strategy}"')
|
||||||
|
|
||||||
actor = GPTActor(config=GPT_CONFIG).cuda()
|
with strategy.model_init_context():
|
||||||
critic = GPTCritic(config=GPT_CONFIG).cuda()
|
actor = GPTActor(config=GPT_CONFIG).cuda()
|
||||||
|
critic = GPTCritic(config=GPT_CONFIG).cuda()
|
||||||
|
|
||||||
initial_model = deepcopy(actor)
|
initial_model = GPTActor(config=GPT_CONFIG).cuda()
|
||||||
reward_model = RewardModel(deepcopy(critic.model)).cuda()
|
reward_model = RewardModel(model=copy.deepcopy(critic.model)).cuda()
|
||||||
|
|
||||||
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model)
|
actor, critic, initial_model, reward_model = strategy.prepare(actor, critic, initial_model, reward_model)
|
||||||
|
|
||||||
|
class MockTokenizer:
|
||||||
|
def __init__(self):
|
||||||
|
self.padding_side = "left"
|
||||||
|
self.eos_token_id = 0
|
||||||
|
self.pad_token_id = 0
|
||||||
|
|
||||||
|
tokenizer = MockTokenizer()
|
||||||
|
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer)
|
||||||
data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)
|
data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)
|
||||||
|
|
||||||
|
generate_kwargs = dict(do_sample=True, max_length=16)
|
||||||
|
generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
|
||||||
|
|
||||||
# experience of all ranks should be the same
|
# experience of all ranks should be the same
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
data = get_data(EXPERIENCE_BATCH_SIZE)
|
data = get_data(EXPERIENCE_BATCH_SIZE)
|
||||||
assert gather_and_equal(data["input_ids"])
|
assert gather_and_equal(data["input_ids"])
|
||||||
assert gather_and_equal(data["attention_mask"])
|
assert gather_and_equal(data["attention_mask"])
|
||||||
experience = experience_maker.make_experience(
|
experience = experience_maker.make_experience(**data, do_sample=True, max_length=16)
|
||||||
**data, do_sample=True, max_length=16, eos_token_id=50256, pad_token_id=50256
|
|
||||||
)
|
|
||||||
assert gather_and_equal(experience.sequences)
|
assert gather_and_equal(experience.sequences)
|
||||||
assert gather_and_equal(experience.action_log_probs)
|
assert gather_and_equal(experience.action_log_probs)
|
||||||
assert gather_and_equal(experience.values)
|
assert gather_and_equal(experience.values)
|
||||||
|
@ -115,4 +127,4 @@ def test_experience(world_size, strategy):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_experience(2, "colossalai")
|
test_experience(2, "colossalai-zero2")
|
||||||
|
|
|
@ -14,7 +14,7 @@ from coati.models.llama import LlamaActor
|
||||||
from coati.models.lora import LoraLinear, convert_to_lora_module
|
from coati.models.lora import LoraLinear, convert_to_lora_module
|
||||||
from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
|
from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
|
||||||
from coati.models.opt import OPTRM, OPTActor, OPTCritic
|
from coati.models.opt import OPTRM, OPTActor, OPTCritic
|
||||||
from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
|
from coati.models.utils import calc_action_log_probs, masked_mean
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batch_size", [4])
|
@pytest.mark.parametrize("batch_size", [4])
|
||||||
|
@ -27,7 +27,6 @@ from coati.models.utils import calc_action_log_probs, compute_reward, masked_mea
|
||||||
# HACK: skip llama due to long execution time
|
# HACK: skip llama due to long execution time
|
||||||
# lambda: LlamaActor(),
|
# lambda: LlamaActor(),
|
||||||
lambda: OPTActor(),
|
lambda: OPTActor(),
|
||||||
# lambda: ChatGLMActor(),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -43,9 +42,16 @@ from coati.models.utils import calc_action_log_probs, compute_reward, masked_mea
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
|
def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
|
||||||
|
class MockTokenizer:
|
||||||
|
def __init__(self):
|
||||||
|
self.padding_side = "left"
|
||||||
|
self.eos_token_id = 0
|
||||||
|
self.pad_token_id = 0
|
||||||
|
|
||||||
actor = actor_maker()
|
actor = actor_maker()
|
||||||
input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
|
input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
|
||||||
sequences = generate(actor.cuda(), input_ids, **generate_kwargs)
|
tokenizer = MockTokenizer()
|
||||||
|
sequences = generate(actor.cuda(), input_ids, tokenizer, **generate_kwargs)
|
||||||
assert sequences.shape == (batch_size, generate_kwargs["max_length"])
|
assert sequences.shape == (batch_size, generate_kwargs["max_length"])
|
||||||
|
|
||||||
|
|
||||||
|
@ -55,24 +61,12 @@ def test_utils():
|
||||||
assert fn_output.dim() == 0
|
assert fn_output.dim() == 0
|
||||||
assert torch.allclose(fn_output, torch.tensor(1.0))
|
assert torch.allclose(fn_output, torch.tensor(1.0))
|
||||||
|
|
||||||
batch_size = 4
|
|
||||||
num_labels = 10
|
|
||||||
fn_input = {
|
|
||||||
"r": torch.ones((batch_size,)),
|
|
||||||
"kl_coef": 1.0,
|
|
||||||
"log_probs": torch.randn((batch_size, num_labels)),
|
|
||||||
"log_probs_base": torch.randn((batch_size, num_labels)),
|
|
||||||
"action_mask": torch.randint(0, 2, (batch_size, num_labels)),
|
|
||||||
}
|
|
||||||
fn_output = compute_reward(**fn_input)
|
|
||||||
assert fn_output.shape == (batch_size,)
|
|
||||||
|
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
seq_len = 32
|
seq_len = 32
|
||||||
num_labels = 10
|
num_labels = 10
|
||||||
num_actions = 2
|
num_actions = 2
|
||||||
fn_input = {
|
fn_input = {
|
||||||
"output": {"logits": torch.randn((batch_size, seq_len, num_labels))},
|
"logits": torch.randn((batch_size, seq_len, num_labels)),
|
||||||
"sequences": torch.randint(0, num_labels, (batch_size, seq_len)),
|
"sequences": torch.randint(0, num_labels, (batch_size, seq_len)),
|
||||||
"num_actions": num_actions,
|
"num_actions": num_actions,
|
||||||
}
|
}
|
||||||
|
@ -135,7 +129,6 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], b
|
||||||
}
|
}
|
||||||
critic_input = {
|
critic_input = {
|
||||||
"sequences": torch.randint(0, 100, (batch_size, seq_len)),
|
"sequences": torch.randint(0, 100, (batch_size, seq_len)),
|
||||||
"action_mask": torch.randint(0, 2, (batch_size, seq_len)),
|
|
||||||
"attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
|
"attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
|
||||||
}
|
}
|
||||||
rm_input = {
|
rm_input = {
|
||||||
|
|
|
@ -24,8 +24,8 @@ if [ -z "$SFT_DATASET" ]; then
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ -z "$PROMPT_PATH" ]; then
|
if [ -z "$PROMPT_DATASET" ]; then
|
||||||
echo "Please set \$PROMPT_PATH to the path to prompts csv."
|
echo "Please set \$PROMPT_DATASET to the path to prompts csv."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
@ -74,11 +74,15 @@ echo "[Test]: testing sft ..."
|
||||||
# FIXME: This is a hack to skip tests that are not working
|
# FIXME: This is a hack to skip tests that are not working
|
||||||
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
|
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
|
||||||
# - llama-*: These tests can be passed locally, skipped for long execution time
|
# - llama-*: These tests can be passed locally, skipped for long execution time
|
||||||
|
# - *-gemini: Gemini plugin does not support `from_pretrained` yet
|
||||||
SKIPPED_TESTS=(
|
SKIPPED_TESTS=(
|
||||||
"gpt2-ddp"
|
"gpt2-ddp"
|
||||||
"llama-ddp"
|
"llama-ddp"
|
||||||
"llama-colossalai_gemini"
|
"llama-colossalai_gemini"
|
||||||
"llama-colossalai_zero2"
|
"llama-colossalai_zero2"
|
||||||
|
"gpt2-colossalai_gemini"
|
||||||
|
"opt-colossalai_gemini"
|
||||||
|
"bloom-colossalai_gemini"
|
||||||
)
|
)
|
||||||
|
|
||||||
GRAD_CKPTS=('' '--grad_checkpoint')
|
GRAD_CKPTS=('' '--grad_checkpoint')
|
||||||
|
@ -105,7 +109,7 @@ for lora_rank in '0' '4'; do
|
||||||
$pretrain_model --tokenizer $MODELS_DIR/$model \
|
$pretrain_model --tokenizer $MODELS_DIR/$model \
|
||||||
--model $model --strategy $strategy --lora_rank $lora_rank $grad_ckpt \
|
--model $model --strategy $strategy --lora_rank $lora_rank $grad_ckpt \
|
||||||
--dataset $SFT_DATASET --max_datasets_size 8 \
|
--dataset $SFT_DATASET --max_datasets_size 8 \
|
||||||
--max_epochs 1 --batch_size 1 --accumulation_steps 1 \
|
--max_epochs 1 --batch_size 1 --accumulation_steps 1 --lr 1e-8 \
|
||||||
--save_path $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank}
|
--save_path $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank}
|
||||||
passed=$?
|
passed=$?
|
||||||
if [ $passed -eq 0 ]; then
|
if [ $passed -eq 0 ]; then
|
||||||
|
@ -125,11 +129,15 @@ echo "[Test]: testing reward model ..."
|
||||||
# FIXME: This is a hack to skip tests that are not working
|
# FIXME: This is a hack to skip tests that are not working
|
||||||
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
|
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
|
||||||
# - llama-*: These tests can be passed locally, skipped for long execution time
|
# - llama-*: These tests can be passed locally, skipped for long execution time
|
||||||
|
# - *-gemini: Gemini plugin does not support `from_pretrained` yet
|
||||||
SKIPPED_TESTS=(
|
SKIPPED_TESTS=(
|
||||||
"gpt2-ddp"
|
"gpt2-ddp"
|
||||||
"llama-ddp"
|
"llama-ddp"
|
||||||
"llama-colossalai_gemini"
|
"llama-colossalai_gemini"
|
||||||
"llama-colossalai_zero2"
|
"llama-colossalai_zero2"
|
||||||
|
"gpt2-colossalai_gemini"
|
||||||
|
"opt-colossalai_gemini"
|
||||||
|
"bloom-colossalai_gemini"
|
||||||
)
|
)
|
||||||
|
|
||||||
LOSS_FNS=('log_sig' 'log_exp')
|
LOSS_FNS=('log_sig' 'log_exp')
|
||||||
|
@ -157,8 +165,9 @@ for lora_rank in '0' '4'; do
|
||||||
echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
|
echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
|
||||||
torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_reward_model.py \
|
torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_reward_model.py \
|
||||||
$pretrain_model --tokenizer $MODELS_DIR/$model \
|
$pretrain_model --tokenizer $MODELS_DIR/$model \
|
||||||
--model $model --strategy $strategy --lora_rank $lora_rank --loss_fn $loss_fn \
|
--dataset $dataset --subset $subset --max_datasets_size 8 \
|
||||||
--dataset $dataset --subset $subset --test True --batch_size 1 \
|
--model $model --strategy $strategy --lora_rank $lora_rank \
|
||||||
|
--loss_fn $loss_fn --batch_size 1 --lr 1e-8 \
|
||||||
--save_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
|
--save_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
|
||||||
passed=$?
|
passed=$?
|
||||||
if [ $passed -eq 0 ]; then
|
if [ $passed -eq 0 ]; then
|
||||||
|
@ -178,11 +187,15 @@ echo "[Test]: testing RLHF ..."
|
||||||
# FIXME: This is a hack to skip tests that are not working
|
# FIXME: This is a hack to skip tests that are not working
|
||||||
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
|
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
|
||||||
# - llama-*: These tests can be passed locally, skipped for long execution time
|
# - llama-*: These tests can be passed locally, skipped for long execution time
|
||||||
|
# - *-gemini: Gemini plugin does not support `from_pretrained` yet
|
||||||
SKIPPED_TESTS=(
|
SKIPPED_TESTS=(
|
||||||
"gpt2-ddp"
|
"gpt2-ddp"
|
||||||
"llama-ddp"
|
"llama-ddp"
|
||||||
"llama-colossalai_gemini"
|
"llama-colossalai_gemini"
|
||||||
"llama-colossalai_zero2"
|
"llama-colossalai_zero2"
|
||||||
|
"gpt2-colossalai_gemini"
|
||||||
|
"opt-colossalai_gemini"
|
||||||
|
"bloom-colossalai_gemini"
|
||||||
)
|
)
|
||||||
|
|
||||||
for model in ${MODELS[@]}; do
|
for model in ${MODELS[@]}; do
|
||||||
|
@ -204,9 +217,9 @@ for model in ${MODELS[@]}; do
|
||||||
for i in $(seq $NUM_RETRY); do
|
for i in $(seq $NUM_RETRY); do
|
||||||
echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
|
echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
|
||||||
torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_prompts.py \
|
torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_prompts.py \
|
||||||
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
|
--prompt_dataset $PROMPT_DATASET --pretrain_dataset $PRETRAIN_DATASET --max_datasets_size 32 \
|
||||||
--strategy $strategy --model $model --tokenizer $MODELS_DIR/$model \
|
--strategy $strategy --model $model --tokenizer $MODELS_DIR/$model \
|
||||||
--num_episodes 1 --num_collect_steps 1 --num_update_steps 1 \
|
--num_episodes 1 --num_collect_steps 1 --num_update_steps 1 --lr 1e-8 \
|
||||||
--experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \
|
--experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \
|
||||||
--pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \
|
--pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \
|
||||||
$rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \
|
$rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \
|
||||||
|
|
|
@ -3,6 +3,7 @@ import time
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from model_zoo import GPTLMLoss, get_gpt2_components
|
||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_map
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
|
@ -13,7 +14,6 @@ from colossalai.fx.profiler import parameter_size
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
from colossalai.testing import spawn
|
from colossalai.testing import spawn
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from model_zoo import GPTLMLoss, get_gpt2_components
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
|
|
@ -3,6 +3,7 @@ import time
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from model_zoo import model_builder
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from colossalai.fx import ColoTracer
|
from colossalai.fx import ColoTracer
|
||||||
|
@ -12,7 +13,6 @@ from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology
|
||||||
from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine
|
from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine
|
||||||
from colossalai.legacy.pipeline.rpc.utils import rpc_run
|
from colossalai.legacy.pipeline.rpc.utils import rpc_run
|
||||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||||
from model_zoo import model_builder
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
|
Loading…
Reference in New Issue