mirror of https://github.com/hpcaitech/ColossalAI
[chat] remove lm model class (#3653)
* [chat] refactor lora
* [chat] remove lm class
* [chat] refactor save model
* [chat] refactor train sft
* [chat] fix ci
* [chat] fix ci

Branch: pull/3662/head
parent 8bccb72c8d
commit 6ef7011462
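In practice, the removed BLOOMLM/GPTLM/OPTLM/LlamaLM wrappers are replaced by applying LoRA directly to a plain Hugging Face causal-LM model through the new `convert_to_lora_module` helper, as the training-script hunks near the end of the diff show. A minimal sketch of the new pattern (the checkpoint name and rank are illustrative):

    from coati.models import convert_to_lora_module
    from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

    # before: model = GPTLM(pretrained='gpt2', lora_rank=8).to(torch.cuda.current_device())
    # after: wrap the plain transformers model with LoRA layers instead
    model = convert_to_lora_module(GPT2LMHeadModel.from_pretrained('gpt2'), lora_rank=8).half().cuda()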
@@ -1,4 +1,8 @@
from .base import Actor, Critic, RewardModel
from .lora import LoRAModule, convert_to_lora_module
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss

__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss']
__all__ = [
    'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss',
    'LoRAModule', 'convert_to_lora_module'
]

@@ -1,6 +1,5 @@
from .actor import Actor
from .critic import Critic
from .lm import LM
from .reward_model import RewardModel

__all__ = ['Actor', 'Critic', 'RewardModel', 'LM']
__all__ = ['Actor', 'Critic', 'RewardModel']
@@ -1,30 +0,0 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..generation import generate
from .actor import Actor


class LM(Actor):
    """
    Language model base class.

    Args:
        model (nn.Module): Language Model.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
        super().__init__(model=model, lora_rank=lora_rank, lora_train_bias=lora_train_bias)

    def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Returns output log probs
        """
        output = self.model(sequences, attention_mask=attention_mask)
        logits = output['logits']
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs
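For reference, everything the deleted `LM.forward` did can be reproduced inline from any causal-LM output; a minimal sketch of the equivalent computation (`model`, `sequences` and `attention_mask` are assumed to exist in the calling code):

    import torch.nn.functional as F

    # same as the removed LM.forward: per-token log-probabilities over the vocabulary
    logits = model(sequences, attention_mask=attention_mask)['logits']
    log_probs = F.log_softmax(logits, dim=-1)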
@@ -1,6 +1,5 @@
from .bloom_actor import BLOOMActor
from .bloom_critic import BLOOMCritic
from .bloom_lm import BLOOMLM
from .bloom_rm import BLOOMRM

__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'BLOOMLM']
__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM']

@@ -1,38 +0,0 @@
from typing import Optional

import torch
from transformers import BloomConfig, BloomForCausalLM, BloomModel

from ..base import LM


class BLOOMLM(LM):
    """
    BLOOM language model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (BloomConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: str = None,
                 config: Optional[BloomConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = BloomForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = BloomForCausalLM(config)
        else:
            model = BloomForCausalLM(BloomConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
@@ -1,6 +1,5 @@
from .gpt_actor import GPTActor
from .gpt_critic import GPTCritic
from .gpt_lm import GPTLM
from .gpt_rm import GPTRM

__all__ = ['GPTActor', 'GPTCritic', 'GPTRM', 'GPTLM']
__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']

@@ -1,38 +0,0 @@
from typing import Optional

from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

from ..base import LM


class GPTLM(LM):
    """
    GPT language model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (GPT2Config): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the LoRa layer.
        lora_train_bias (str): Bias training strategy for the LoRa layer.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[GPT2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = GPT2LMHeadModel.from_pretrained(pretrained)
        elif config is not None:
            model = GPT2LMHeadModel(config)
        else:
            model = GPT2LMHeadModel(GPT2Config())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
@@ -1,6 +1,5 @@
from .llama_actor import LlamaActor
from .llama_critic import LlamaCritic
from .llama_lm import LlamaLM
from .llama_rm import LlamaRM

__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM', 'LlamaLM']
__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']

@@ -1,40 +0,0 @@
from typing import Optional

from transformers import LlamaConfig, LlamaForCausalLM

from ..base import LM


class LlamaLM(LM):
    """
    Llama language model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (LlamaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:

        if pretrained is not None:
            model = LlamaForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = LlamaForCausalLM(config)
        else:
            model = LlamaForCausalLM(LlamaConfig())

        if checkpoint:
            model.gradient_checkpointing_enable()

        super().__init__(model, lora_rank, lora_train_bias)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
@@ -106,6 +106,23 @@ def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
        convert_to_lora_recursively(child, lora_rank)


def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
    """Convert a torch.nn.Module to a LoRA module.

    Args:
        module (nn.Module): The module to convert.
        lora_rank (int): LoRA rank.

    Returns:
        nn.Module: The converted module.
    """
    if lora_rank <= 0:
        return module
    convert_to_lora_recursively(module, lora_rank)
    lora.mark_only_lora_as_trainable(module, lora_train_bias)
    return module


class LoRAModule(nn.Module):
    """A LoRA module base class. All derived classes should call `convert_to_lora()` at the bottom of `__init__()`.
    This class will convert all torch.nn.Linear layer to LoraLinear layer.

@@ -123,7 +140,4 @@ class LoRAModule(nn.Module):
        self.lora_train_bias = lora_train_bias

    def convert_to_lora(self) -> None:
        if self.lora_rank <= 0:
            return
        convert_to_lora_recursively(self, self.lora_rank)
        lora.mark_only_lora_as_trainable(self, self.lora_train_bias)
        convert_to_lora_module(self, self.lora_rank, self.lora_train_bias)
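The new helper is also exported from `coati.models` (see the first hunk above), so callers no longer need to subclass `LoRAModule`. A small, self-contained sketch of how it is meant to be used (the toy module and rank are illustrative):

    import torch.nn as nn
    from coati.models import convert_to_lora_module

    toy = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
    toy = convert_to_lora_module(toy, lora_rank=8)
    # after conversion, only the injected LoRA parameters should still require gradients
    print([name for name, p in toy.named_parameters() if p.requires_grad])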
@@ -1,6 +1,5 @@
from .opt_actor import OPTActor
from .opt_critic import OPTCritic
from .opt_lm import OPTLM
from .opt_rm import OPTRM

__all__ = ['OPTActor', 'OPTCritic', 'OPTRM', 'OPTLM']
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM']

@@ -1,38 +0,0 @@
from typing import Optional

from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM

from ..base import LM


class OPTLM(LM):
    """
    OPT language model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (OPTConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[OPTConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = OPTForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = OPTForCausalLM(config)
        else:
            model = OPTForCausalLM(OPTConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
@@ -2,26 +2,19 @@ import math
import time
from typing import List, Optional

import loralib as lora
import torch
import torch.distributed as dist
import wandb
from coati.models.loss import GPTLMLoss
from torch import nn
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import get_scheduler

from colossalai.logging import get_dist_logger

from .base import Trainer
from .callbacks import Callback
from .strategies import Strategy
from .utils import is_rank_0
from .strategies import ColossalAIStrategy, Strategy
from .utils import is_rank_0, to_device


class SFTTrainer(Trainer):
@@ -47,19 +40,17 @@
                 optim: Optimizer,
                 train_dataloader: DataLoader,
                 eval_dataloader: DataLoader = None,
                 batch_size: int = 1,
                 max_epochs: int = 2,
                 accimulation_steps: int = 8,
                 callbacks: List[Callback] = [],
                 ) -> None:
        if accimulation_steps > 1 and isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3:
            raise ValueError("Accumulation steps are not supported in stage 3 of ColossalAI")
        super().__init__(strategy, max_epochs, callbacks=callbacks)
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader

        self.model = strategy.setup_model(model)
        if "DDP" in str(self.strategy):
            self.model = self.model.module
        self.optimizer = strategy.setup_optimizer(optim, self.model)
        (self.model, self.optimizer) = strategy.prepare((model, optim))

        self.accimulation_steps = accimulation_steps
        num_update_steps_per_epoch = len(train_dataloader) // self.accimulation_steps
@@ -86,17 +77,10 @@
        self.model.train()
        for batch_id, batch in enumerate(self.train_dataloader):

            prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
            p_mask = batch["attention_mask"].to(torch.cuda.current_device())
            labels = batch["labels"].to(torch.cuda.current_device())
            # prompt_ids = prompt_ids.squeeze(1).cuda()
            # p_mask = p_mask.squeeze(1).cuda()
            # prompt_logits = self.model(prompt_ids, attention_mask=p_mask, labels=labels)

            outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
            batch = to_device(batch, torch.cuda.current_device())
            outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])

            loss = outputs.loss
            prompt_logits = outputs.logits

            if loss >= 2.5 and is_rank_0():
                logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")
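The training loop now moves the whole batch dict in one call instead of per-tensor `.to(...)` calls. A rough sketch of the equivalent by-hand version (this is what `to_device` is presumed to do for dict batches; `model` and `batch` are assumed to exist):

    import torch

    device = torch.cuda.current_device()
    # move every tensor in the supervised batch to the current GPU
    batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
    outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
    loss = outputs.loss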
@@ -135,18 +119,14 @@
        loss_sum = 0
        num_seen = 0
        for batch in self.eval_dataloader:
            prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
            p_mask = batch["attention_mask"].to(torch.cuda.current_device())
            labels = batch["labels"].to(torch.cuda.current_device())
            # prompt_ids = prompt_ids.squeeze(1).cuda()
            # p_mask = p_mask.squeeze(1).cuda()

            outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
            batch = to_device(batch, torch.cuda.current_device())
            outputs = self.model(batch["input_ids"],
                                 attention_mask=batch["attention_mask"],
                                 labels=batch["labels"])
            loss = outputs.loss
            # prompt_logits = outputs.logits

            loss_sum += loss.item()
            num_seen += prompt_ids.size(0)
            num_seen += batch["input_ids"].size(0)

        loss_mean = loss_sum / num_seen
        if dist.get_rank() == 0:
@@ -5,7 +5,7 @@ from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from coati.models.base import LM, Actor, Critic, RewardModel
from coati.models.base import Actor, Critic, RewardModel
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader

@@ -99,7 +99,7 @@ class Strategy(ABC):
        Args:
            model (nn.Module): an actor or a critic
        """
        if isinstance(model, Actor) or isinstance(model, LM):
        if isinstance(model, Actor):
            return model.model
        return model
@@ -5,7 +5,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from coati.models.base import LM, Actor, RewardModel
from coati.models.base import Actor, RewardModel
from coati.models.lora import LoraLinear
from torch.optim import Optimizer
from transformers.modeling_utils import PreTrainedModel

@@ -173,10 +173,6 @@ class ColossalAIStrategy(DDPStrategy):
        # TODO : better way to get torch model from gemini model
        # to get torch model from gemini model

        for module in unwrapped_model.modules():
            if isinstance(module, LoraLinear):
                module.merge_weights = True
                module.eval()
        if isinstance(unwrapped_model, RewardModel):
            state_dict = unwrapped_model.state_dict()
            if only_rank0 and dist.get_rank() != 0:

@@ -184,8 +180,6 @@
            torch.save(state_dict, path)
        else:
            try:
                if isinstance(unwrapped_model, LM):
                    unwrapped_model = unwrapped_model.model
                logger.info(f'Saving model to {path}', ranks=[0])
                unwrapped_model.save_pretrained(path)
                logger.info(f'Model saved to {path} Successfully', ranks=[0])
@@ -1,14 +1,12 @@
from typing import Optional

import os
import random
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from coati.models.base import LM, Actor, RewardModel
from coati.models.lora import LoraLinear
from coati.models.base import Actor, RewardModel
from coati.replay_buffer import ReplayBuffer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer

@@ -75,15 +73,14 @@ class DDPStrategy(NaiveStrategy):
        model: DDP = Strategy._unwrap_actor(actor)
        return model.module

    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
    def save_model(self,
                   model: nn.Module,
                   path: str,
                   only_rank0: bool = False,
                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
        if only_rank0 and dist.get_rank() != 0:
            return None

        for module in model.modules():
            if isinstance(module, LoraLinear):
                module.merge_weights = True
                module.eval()


        if isinstance(model, RewardModel):
            state_dict = model.state_dict()
            if only_rank0 and dist.get_rank() != 0:

@@ -91,8 +88,6 @@
            torch.save(state_dict, path)
        else:
            try:
                if isinstance(model, LM):
                    model = model.model
                model.save_pretrained(path)
                if tokenizer is not None:
                    tokenizer.save_pretrained(path)
@@ -3,9 +3,8 @@ from typing import Any, Optional
import torch
import torch.nn as nn
import torch.optim as optim
from coati.models.base import RewardModel
from coati.replay_buffer import ReplayBuffer
from coati.models.base import LM, RewardModel
from coati.models.lora import LoraLinear
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

@@ -41,19 +40,16 @@ class NaiveStrategy(Strategy):
                          pin_memory=pin_memory,
                          collate_fn=replay_buffer.collate_fn)

    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
        for module in model.modules():
            if isinstance(module, LoraLinear):
                module.merge_weights = True
                module.eval()

    def save_model(self,
                   model: nn.Module,
                   path: str,
                   only_rank0: bool = False,
                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
        if isinstance(model, RewardModel):
            state_dict = model.state_dict()
            torch.save(state_dict, path)
        else:
            try:
                if isinstance(model, LM):
                    model = model.model
                model.save_pretrained(path)
                if tokenizer is not None:
                    tokenizer.save_pretrained(path)
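With the LoRA weight-merging loop removed from `save_model` in these strategies, saving now only dispatches on the model type and optionally persists the tokenizer next to the weights. An illustrative call, assuming a trained `model`, a matching `tokenizer`, and one of the strategies shown above (the path is a placeholder):

    # reward models are saved as a state dict; causal-LM models via save_pretrained
    strategy.save_model(model, 'output/sft_model', only_rank0=True, tokenizer=tokenizer)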
@@ -16,8 +16,6 @@ from typing import Dict

import transformers

from ..models.llama.llama_lm import LlamaLM

DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"

@@ -62,9 +60,6 @@ def smart_tokenizer_and_embedding_resize(
    if tokenizer.pad_token is None:
        num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)

        if isinstance(model, LlamaLM):
            model = model.get_base_model()

        model.resize_token_embeddings(len(tokenizer))

        if num_new_tokens > 0:
@@ -31,16 +31,19 @@ torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'bigsci
    --model 'bloom' --strategy colossalai_zero2 --lora_rank 4\
    --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
    --save_path ${BASE}/output
rm -rf ${BASE}/output

torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
    --model 'gpt2' --strategy colossalai_zero2 \
    --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
    --save_path ${BASE}/output
rm -rf ${BASE}/output

torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \
    --model 'opt' --strategy colossalai_zero2 --lora_rank 4\
    --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
    --save_path ${BASE}/output
rm -rf ${BASE}/output

torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
    --model 'gpt2' --strategy ddp --lora_rank 4\

@@ -59,14 +62,14 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
    --pretrain 'facebook/opt-350m' --model 'opt' \
    --strategy colossalai_zero2 --loss_fn 'log_sig'\
    --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
    --test True --lora_rank 4 \
    --test True --lora_rank 0 \
    --save_path ${BASE}/rm_ckpt_opt.pt

torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
    --pretrain 'gpt2' --model 'gpt2' \
    --strategy colossalai_zero2 --loss_fn 'log_exp' \
    --dataset 'Dahoas/rm-static' \
    --test True --lora_rank 4 \
    --test True --lora_rank 0 \
    --save_path ${BASE}/rm_ckpt_gpt.pt

torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \

@@ -75,6 +78,7 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
    --dataset 'Dahoas/rm-static' \
    --test True --lora_rank 4 \
    --save_path ${BASE}/rm_ckpt.pt
rm -rf ${BASE}/rm_ckpt.pt

torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
    --pretrain 'bigscience/bloom-560m' --model 'bloom' \

@@ -82,6 +86,7 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
    --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
    --test True --lora_rank 4 \
    --save_path ${BASE}/rm_ckpt.pt
rm -rf ${BASE}/rm_ckpt.pt

torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
    --pretrain 'microsoft/deberta-v3-large' --model 'deberta' \

@@ -89,6 +94,7 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
    --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
    --test True --lora_rank 4 \
    --save_path ${BASE}/rm_ckpt.pt
rm -rf ${BASE}/rm_ckpt.pt

torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
    --pretrain 'roberta-base' --model 'roberta' \

@@ -117,4 +123,4 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_datas
    --save_path ${BASE}/actor_checkpoint_prompts.pt
rm -rf ${BASE}/rm_ckpt_gpt.pt

rm -rf ${BASE}/actor_checkpoint_prompts.pt
rm -rf ${BASE}/actor_checkpoint_prompts.pt
@@ -5,11 +5,7 @@ import loralib as lora
import torch
import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
from coati.models.base import RewardModel
from coati.models.bloom import BLOOMLM
from coati.models.gpt import GPTLM
from coati.models.llama import LlamaLM
from coati.models.opt import OPTLM
from coati.models import convert_to_lora_module
from coati.trainer import SFTTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding

@@ -17,8 +13,12 @@ from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers import AutoTokenizer, BloomConfig, BloomForCausalLM, BloomTokenizerFast, LlamaConfig, LlamaForCausalLM
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM

from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam
@@ -32,6 +32,8 @@ def train(args):
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        raise NotImplementedError(
            'Gemini is not supported .from_pretrained() yet. We will update this after checkpoint io is ready.')
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
@@ -43,16 +45,19 @@ def train(args):
    # configure model
    with strategy.model_init_context():
        if args.model == 'bloom':
            model = BLOOMLM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
            model = convert_to_lora_module(BloomForCausalLM.from_pretrained(args.pretrain),
                                           args.lora_rank).half().cuda()
        elif args.model == 'opt':
            model = OPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
            model = convert_to_lora_module(OPTForCausalLM.from_pretrained(args.pretrain), args.lora_rank).half().cuda()
        elif args.model == 'gpt2':
            model = GPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
            model = convert_to_lora_module(GPT2LMHeadModel.from_pretrained(args.pretrain), args.lora_rank).half().cuda()
        elif args.model == 'llama':
            model = LlamaLM(pretrained=args.pretrain, lora_rank=args.lora_rank,
                            checkpoint=True).to(torch.float16).to(torch.cuda.current_device())
            model = convert_to_lora_module(LlamaForCausalLM.from_pretrained(args.pretrain),
                                           args.lora_rank).half().cuda()
        else:
            raise ValueError(f'Unsupported model "{args.model}"')
        if args.grad_checkpoint:
            model.gradient_checkpointing_enable()

    # configure tokenizer
    if args.model == 'gpt2':
@@ -152,7 +157,6 @@ def train(args):
                         optim=optim,
                         train_dataloader=train_dataloader,
                         eval_dataloader=eval_dataloader,
                         batch_size=args.batch_size,
                         max_epochs=args.max_epochs,
                         accimulation_steps=args.accimulation_steps)
@@ -186,5 +190,6 @@ if __name__ == '__main__':
    parser.add_argument('--lr', type=float, default=5e-6)
    parser.add_argument('--accimulation_steps', type=int, default=8)
    parser.add_argument('--use_wandb', default=False, action='store_true')
    parser.add_argument('--grad_checkpoint', default=False, action='store_true')
    args = parser.parse_args()
    train(args)