[chat] remove lm model class (#3653)

* [chat] refactor lora

* [chat] remove lm class

* [chat] refactor save model

* [chat] refactor train sft

* [chat] fix ci

* [chat] fix ci
pull/3662/head
Hongxin Liu 2023-04-27 15:37:38 +08:00 committed by GitHub
parent 8bccb72c8d
commit 6ef7011462
20 changed files with 84 additions and 284 deletions

View File

@@ -1,4 +1,8 @@
from .base import Actor, Critic, RewardModel
from .lora import LoRAModule, convert_to_lora_module
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss']
__all__ = [
'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss',
'LoRAModule', 'convert_to_lora_module'
]

View File

@@ -1,6 +1,5 @@
from .actor import Actor
from .critic import Critic
from .lm import LM
from .reward_model import RewardModel
__all__ = ['Actor', 'Critic', 'RewardModel', 'LM']
__all__ = ['Actor', 'Critic', 'RewardModel']

View File

@@ -1,30 +0,0 @@
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..generation import generate
from .actor import Actor
class LM(Actor):
"""
Language model base class.
Args:
model (nn.Module): Language Model.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
super().__init__(model=model, lora_rank=lora_rank, lora_train_bias=lora_train_bias)
def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Returns output log probs
"""
output = self.model(sequences, attention_mask=attention_mask)
logits = output['logits']
log_probs = F.log_softmax(logits, dim=-1)
return log_probs
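With the LM wrapper removed, the per-token log probabilities it computed can be obtained directly from any Hugging Face causal LM. A minimal sketch of the equivalent computation (the 'gpt2' checkpoint and input text are illustrative, not part of this commit):

import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

batch = tokenizer('Hello, world', return_tensors='pt')
with torch.no_grad():
    output = model(batch['input_ids'], attention_mask=batch['attention_mask'])
# Same result LM.forward used to return: log-softmax over the vocabulary.
log_probs = F.log_softmax(output.logits, dim=-1)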

View File

@@ -1,6 +1,5 @@
from .bloom_actor import BLOOMActor
from .bloom_critic import BLOOMCritic
from .bloom_lm import BLOOMLM
from .bloom_rm import BLOOMRM
__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'BLOOMLM']
__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM']

View File

@@ -1,38 +0,0 @@
from typing import Optional
import torch
from transformers import BloomConfig, BloomForCausalLM, BloomModel
from ..base import LM
class BLOOMLM(LM):
"""
BLOOM language model.
Args:
pretrained (str): Pretrained model name or path.
config (BloomConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = BloomForCausalLM.from_pretrained(pretrained)
elif config is not None:
model = BloomForCausalLM(config)
else:
model = BloomForCausalLM(BloomConfig())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)

View File

@@ -1,6 +1,5 @@
from .gpt_actor import GPTActor
from .gpt_critic import GPTCritic
from .gpt_lm import GPTLM
from .gpt_rm import GPTRM
__all__ = ['GPTActor', 'GPTCritic', 'GPTRM', 'GPTLM']
__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']

View File

@@ -1,38 +0,0 @@
from typing import Optional
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from ..base import LM
class GPTLM(LM):
"""
GPT language model.
Args:
pretrained (str): Pretrained model name or path.
config (GPT2Config): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the LoRa layer.
lora_train_bias (str): Bias training strategy for the LoRa layer.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = GPT2LMHeadModel.from_pretrained(pretrained)
elif config is not None:
model = GPT2LMHeadModel(config)
else:
model = GPT2LMHeadModel(GPT2Config())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)

View File

@@ -1,6 +1,5 @@
from .llama_actor import LlamaActor
from .llama_critic import LlamaCritic
from .llama_lm import LlamaLM
from .llama_rm import LlamaRM
__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM', 'LlamaLM']
__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']

View File

@@ -1,40 +0,0 @@
from typing import Optional
from transformers import LlamaConfig, LlamaForCausalLM
from ..base import LM
class LlamaLM(LM):
"""
Llama language model.
Args:
pretrained (str): Pretrained model name or path.
config (LlamaConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained)
elif config is not None:
model = LlamaForCausalLM(config)
else:
model = LlamaForCausalLM(LlamaConfig())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)

View File

@@ -106,6 +106,23 @@ def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
convert_to_lora_recursively(child, lora_rank)
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
"""Convert a torch.nn.Module to a LoRA module.
Args:
module (nn.Module): The module to convert.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
Returns:
nn.Module: The converted module.
"""
if lora_rank <= 0:
return module
convert_to_lora_recursively(module, lora_rank)
lora.mark_only_lora_as_trainable(module, lora_train_bias)
return module
class LoRAModule(nn.Module):
"""A LoRA module base class. All derived classes should call `convert_to_lora()` at the bottom of `__init__()`.
This class will convert all torch.nn.Linear layer to LoraLinear layer.
@@ -123,7 +140,4 @@ class LoRAModule(nn.Module):
self.lora_train_bias = lora_train_bias
def convert_to_lora(self) -> None:
if self.lora_rank <= 0:
return
convert_to_lora_recursively(self, self.lora_rank)
lora.mark_only_lora_as_trainable(self, self.lora_train_bias)
convert_to_lora_module(self, self.lora_rank, self.lora_train_bias)
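convert_to_lora_module makes LoRA injection available without subclassing LoRAModule, which is what the refactored call sites elsewhere in this commit rely on. A short usage sketch (module shapes are illustrative):

import torch.nn as nn
from coati.models import convert_to_lora_module

# Every nn.Linear inside the module is replaced in place by a LoraLinear of
# the given rank, and only the LoRA parameters are left trainable.
base = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
model = convert_to_lora_module(base, lora_rank=4, lora_train_bias='none')

# A rank of 0 (the default) is a no-op: the module is returned untouched.
plain = convert_to_lora_module(nn.Linear(8, 8), lora_rank=0)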

View File

@@ -1,6 +1,5 @@
from .opt_actor import OPTActor
from .opt_critic import OPTCritic
from .opt_lm import OPTLM
from .opt_rm import OPTRM
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM', 'OPTLM']
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM']

View File

@@ -1,38 +0,0 @@
from typing import Optional
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
from ..base import LM
class OPTLM(LM):
"""
OPT language model.
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = OPTForCausalLM.from_pretrained(pretrained)
elif config is not None:
model = OPTForCausalLM(config)
else:
model = OPTForCausalLM(OPTConfig())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)

View File

@@ -2,26 +2,19 @@ import math
import time
from typing import List, Optional
import loralib as lora
import torch
import torch.distributed as dist
import wandb
from coati.models.loss import GPTLMLoss
from torch import nn
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import get_scheduler
from colossalai.logging import get_dist_logger
from .base import Trainer
from .callbacks import Callback
from .strategies import Strategy
from .utils import is_rank_0
from .strategies import ColossalAIStrategy, Strategy
from .utils import is_rank_0, to_device
class SFTTrainer(Trainer):
@@ -47,19 +40,17 @@ class SFTTrainer(Trainer):
optim: Optimizer,
train_dataloader: DataLoader,
eval_dataloader: DataLoader = None,
batch_size: int = 1,
max_epochs: int = 2,
accimulation_steps: int = 8,
callbacks: List[Callback] = [],
) -> None:
if accimulation_steps > 1 and isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3:
raise ValueError("Accumulation steps are not supported in stage 3 of ColossalAI")
super().__init__(strategy, max_epochs, callbacks=callbacks)
self.train_dataloader = train_dataloader
self.eval_dataloader = eval_dataloader
self.model = strategy.setup_model(model)
if "DDP" in str(self.strategy):
self.model = self.model.module
self.optimizer = strategy.setup_optimizer(optim, self.model)
(self.model, self.optimizer) = strategy.prepare((model, optim))
self.accimulation_steps = accimulation_steps
num_update_steps_per_epoch = len(train_dataloader) // self.accimulation_steps
@@ -86,17 +77,10 @@ class SFTTrainer(Trainer):
self.model.train()
for batch_id, batch in enumerate(self.train_dataloader):
prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
p_mask = batch["attention_mask"].to(torch.cuda.current_device())
labels = batch["labels"].to(torch.cuda.current_device())
# prompt_ids = prompt_ids.squeeze(1).cuda()
# p_mask = p_mask.squeeze(1).cuda()
# prompt_logits = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
loss = outputs.loss
prompt_logits = outputs.logits
if loss >= 2.5 and is_rank_0():
logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")
@@ -135,18 +119,14 @@
loss_sum = 0
num_seen = 0
for batch in self.eval_dataloader:
prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
p_mask = batch["attention_mask"].to(torch.cuda.current_device())
labels = batch["labels"].to(torch.cuda.current_device())
# prompt_ids = prompt_ids.squeeze(1).cuda()
# p_mask = p_mask.squeeze(1).cuda()
outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"])
loss = outputs.loss
# prompt_logits = outputs.logits
loss_sum += loss.item()
num_seen += prompt_ids.size(0)
num_seen += batch["input_ids"].size(0)
loss_mean = loss_sum / num_seen
if dist.get_rank() == 0:
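The per-tensor .to(...) calls above collapse into a single to_device helper imported from .utils. Its implementation is not shown in this diff; a plausible sketch, assuming it simply maps every tensor in a dict-style batch to the target device:

from typing import Any, Dict
import torch

def to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
    # Move tensor values to the target device and pass everything else
    # through unchanged. (Sketch only; the real helper may differ.)
    return {k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()}

The same file also replaces the separate setup_model/setup_optimizer calls with a single strategy.prepare((model, optim)), so the trainer no longer unwraps DDP modules by hand.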

View File

@@ -5,7 +5,7 @@ from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from coati.models.base import LM, Actor, Critic, RewardModel
from coati.models.base import Actor, Critic, RewardModel
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
@@ -99,7 +99,7 @@ class Strategy(ABC):
Args:
model (nn.Module): an actor or a critic
"""
if isinstance(model, Actor) or isinstance(model, LM):
if isinstance(model, Actor):
return model.model
return model

View File

@@ -5,7 +5,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from coati.models.base import LM, Actor, RewardModel
from coati.models.base import Actor, RewardModel
from coati.models.lora import LoraLinear
from torch.optim import Optimizer
from transformers.modeling_utils import PreTrainedModel
@@ -173,10 +173,6 @@ class ColossalAIStrategy(DDPStrategy):
# TODO : better way to get torch model from gemini model
# to get torch model from gemini model
for module in unwrapped_model.modules():
if isinstance(module, LoraLinear):
module.merge_weights = True
module.eval()
if isinstance(unwrapped_model, RewardModel):
state_dict = unwrapped_model.state_dict()
if only_rank0 and dist.get_rank() != 0:
@@ -184,8 +180,6 @@ class ColossalAIStrategy(DDPStrategy):
torch.save(state_dict, path)
else:
try:
if isinstance(unwrapped_model, LM):
unwrapped_model = unwrapped_model.model
logger.info(f'Saving model to {path}', ranks=[0])
unwrapped_model.save_pretrained(path)
logger.info(f'Model saved to {path} Successfully', ranks=[0])

View File

@@ -1,14 +1,12 @@
from typing import Optional
import os
import random
from typing import Optional
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from coati.models.base import LM, Actor, RewardModel
from coati.models.lora import LoraLinear
from coati.models.base import Actor, RewardModel
from coati.replay_buffer import ReplayBuffer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
@@ -75,15 +73,14 @@ class DDPStrategy(NaiveStrategy):
model: DDP = Strategy._unwrap_actor(actor)
return model.module
def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
def save_model(self,
model: nn.Module,
path: str,
only_rank0: bool = False,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
if only_rank0 and dist.get_rank() != 0:
return None
for module in model.modules():
if isinstance(module, LoraLinear):
module.merge_weights = True
module.eval()
if isinstance(model, RewardModel):
state_dict = model.state_dict()
if only_rank0 and dist.get_rank() != 0:
@@ -91,8 +88,6 @@ class DDPStrategy(NaiveStrategy):
torch.save(state_dict, path)
else:
try:
if isinstance(model, LM):
model = model.model
model.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)

View File

@@ -3,9 +3,8 @@ from typing import Any, Optional
import torch
import torch.nn as nn
import torch.optim as optim
from coati.models.base import RewardModel
from coati.replay_buffer import ReplayBuffer
from coati.models.base import LM, RewardModel
from coati.models.lora import LoraLinear
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
@@ -41,19 +40,16 @@ class NaiveStrategy(Strategy):
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)
def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
for module in model.modules():
if isinstance(module, LoraLinear):
module.merge_weights = True
module.eval()
def save_model(self,
model: nn.Module,
path: str,
only_rank0: bool = False,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
if isinstance(model, RewardModel):
state_dict = model.state_dict()
torch.save(state_dict, path)
else:
try:
if isinstance(model, LM):
model = model.model
model.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
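With the LM branch and LoRA merge loop gone, every non-reward model takes the save_pretrained path. A minimal sketch of what that produces (model name and output path are illustrative):

from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# save_pretrained writes a directory that from_pretrained can reload,
# covering both the model weights and the tokenizer files.
model.save_pretrained('output_dir')
tokenizer.save_pretrained('output_dir')
reloaded = GPT2LMHeadModel.from_pretrained('output_dir')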

View File

@@ -16,8 +16,6 @@ from typing import Dict
import transformers
from ..models.llama.llama_lm import LlamaLM
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
@@ -62,9 +60,6 @@ def smart_tokenizer_and_embedding_resize(
if tokenizer.pad_token is None:
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
if isinstance(model, LlamaLM):
model = model.get_base_model()
model.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:

View File

@@ -31,16 +31,19 @@ torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'bigsci
--model 'bloom' --strategy colossalai_zero2 --lora_rank 4\
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
--save_path ${BASE}/output
rm -rf ${BASE}/output
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
--model 'gpt2' --strategy colossalai_zero2 \
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
--save_path ${BASE}/output
rm -rf ${BASE}/output
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \
--model 'opt' --strategy colossalai_zero2 --lora_rank 4\
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
--save_path ${BASE}/output
rm -rf ${BASE}/output
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
--model 'gpt2' --strategy ddp --lora_rank 4\
@@ -59,14 +62,14 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'facebook/opt-350m' --model 'opt' \
--strategy colossalai_zero2 --loss_fn 'log_sig'\
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
--test True --lora_rank 4 \
--test True --lora_rank 0 \
--save_path ${BASE}/rm_ckpt_opt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'gpt2' --model 'gpt2' \
--strategy colossalai_zero2 --loss_fn 'log_exp' \
--dataset 'Dahoas/rm-static' \
--test True --lora_rank 4 \
--test True --lora_rank 0 \
--save_path ${BASE}/rm_ckpt_gpt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
@@ -75,6 +78,7 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--dataset 'Dahoas/rm-static' \
--test True --lora_rank 4 \
--save_path ${BASE}/rm_ckpt.pt
rm -rf ${BASE}/rm_ckpt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'bigscience/bloom-560m' --model 'bloom' \
@@ -82,6 +86,7 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
--test True --lora_rank 4 \
--save_path ${BASE}/rm_ckpt.pt
rm -rf ${BASE}/rm_ckpt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'microsoft/deberta-v3-large' --model 'deberta' \
@@ -89,6 +94,7 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
--test True --lora_rank 4 \
--save_path ${BASE}/rm_ckpt.pt
rm -rf ${BASE}/rm_ckpt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'roberta-base' --model 'roberta' \
@@ -117,4 +123,4 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_datas
--save_path ${BASE}/actor_checkpoint_prompts.pt
rm -rf ${BASE}/rm_ckpt_gpt.pt
rm -rf ${BASE}/actor_checkpoint_prompts.pt
rm -rf ${BASE}/actor_checkpoint_prompts.pt

View File

@@ -5,11 +5,7 @@ import loralib as lora
import torch
import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
from coati.models.base import RewardModel
from coati.models.bloom import BLOOMLM
from coati.models.gpt import GPTLM
from coati.models.llama import LlamaLM
from coati.models.opt import OPTLM
from coati.models import convert_to_lora_module
from coati.trainer import SFTTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
@@ -17,8 +13,12 @@ from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers import AutoTokenizer, BloomConfig, BloomForCausalLM, BloomTokenizerFast, LlamaConfig, LlamaForCausalLM
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam
@@ -32,6 +32,8 @@ def train(args):
elif args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
raise NotImplementedError(
'Gemini does not support .from_pretrained() yet. We will update this after checkpoint io is ready.')
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
elif args.strategy == 'colossalai_zero2':
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
@@ -43,16 +45,19 @@
# configure model
with strategy.model_init_context():
if args.model == 'bloom':
model = BLOOMLM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
model = convert_to_lora_module(BloomForCausalLM.from_pretrained(args.pretrain),
args.lora_rank).half().cuda()
elif args.model == 'opt':
model = OPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
model = convert_to_lora_module(OPTForCausalLM.from_pretrained(args.pretrain), args.lora_rank).half().cuda()
elif args.model == 'gpt2':
model = GPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
model = convert_to_lora_module(GPT2LMHeadModel.from_pretrained(args.pretrain), args.lora_rank).half().cuda()
elif args.model == 'llama':
model = LlamaLM(pretrained=args.pretrain, lora_rank=args.lora_rank,
checkpoint=True).to(torch.float16).to(torch.cuda.current_device())
model = convert_to_lora_module(LlamaForCausalLM.from_pretrained(args.pretrain),
args.lora_rank).half().cuda()
else:
raise ValueError(f'Unsupported model "{args.model}"')
if args.grad_checkpoint:
model.gradient_checkpointing_enable()
# configure tokenizer
if args.model == 'gpt2':
@@ -152,7 +157,6 @@ def train(args):
optim=optim,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
batch_size=args.batch_size,
max_epochs=args.max_epochs,
accimulation_steps=args.accimulation_steps)
@@ -186,5 +190,6 @@ if __name__ == '__main__':
parser.add_argument('--lr', type=float, default=5e-6)
parser.add_argument('--accimulation_steps', type=int, default=8)
parser.add_argument('--use_wandb', default=False, action='store_true')
parser.add_argument('--grad_checkpoint', default=False, action='store_true')
args = parser.parse_args()
train(args)
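Taken together, model setup in train_sft.py now follows one pattern for every architecture: load the Hugging Face model, wrap it with convert_to_lora_module, then enable gradient checkpointing via the new flag. A condensed sketch of that flow (values stand in for the parsed args):

from coati.models import convert_to_lora_module
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

lora_rank = 4           # stands in for args.lora_rank
grad_checkpoint = True  # stands in for args.grad_checkpoint

model = convert_to_lora_module(GPT2LMHeadModel.from_pretrained('gpt2'),
                               lora_rank).half().cuda()
if grad_checkpoint:
    model.gradient_checkpointing_enable()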