[chatgpt] add pre-trained model RoBERTa for RLHF stage 2 & 3 (#3223)

* Add RoBERTa for RLHF Stage 2 & 3 (test)

RoBERTa for RLHF Stage 2 & 3 (still in testing)

* Revert "Add RoBERTa for RLHF Stage 2 & 3 (test)"

This reverts commit 06741d894d.

* Add RoBERTa for RLHF stage 2 & 3

1. Add a roberta folder under the models folder
2. Add a roberta option in train_reward_model.py
3. Add tests to test_ci.sh

* Add test for reward model training

* Update test_ci.sh

* Revert "Update test_ci.sh"

This reverts commit 9c7352b81766f3177d31eeec0ec178a301df966a.

* Update the RoBERTa models to integrate with Coati
Camille Zhong 2023-04-03 10:11:03 +08:00 committed by GitHub
parent 94c24d9444
commit 30412866e0
9 changed files with 173 additions and 9 deletions

coati/models/roberta/__init__.py

@@ -0,0 +1,5 @@
from .roberta_actor import RoBERTaActor
from .roberta_critic import RoBERTaCritic
from .roberta_rm import RoBERTaRM

__all__ = ['RoBERTaActor', 'RoBERTaCritic', 'RoBERTaRM']

coati/models/roberta/roberta_actor.py

@@ -0,0 +1,35 @@
from typing import Optional

from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.models.roberta.modeling_roberta import RobertaForCausalLM

from ..base import Actor


class RoBERTaActor(Actor):
    """
    RoBERTa Actor model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (RobertaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[RobertaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = RobertaForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = RobertaForCausalLM(config)
        else:
            model = RobertaForCausalLM(RobertaConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)
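
A quick usage sketch (editorial, not part of the diff), assuming the coati package from this repo is importable. Note that RobertaForCausalLM repurposes a bidirectional encoder for left-to-right generation, so transformers warns unless config.is_decoder is set to True:

import torch
from coati.models.roberta import RoBERTaActor

# Hypothetical usage: load 'roberta-base' weights from HuggingFace with a LoRA rank of 4.
actor = RoBERTaActor(pretrained='roberta-base', lora_rank=4)
# RoBERTa was pretrained with masked (not causal) attention; the commit message
# itself flags the RoBERTa path as still in testing, so treat generation quality
# with caution.
print(sum(p.numel() for p in actor.parameters()))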

coati/models/roberta/roberta_critic.py

@@ -0,0 +1,38 @@
from typing import Optional

import torch.nn as nn
from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.models.roberta.modeling_roberta import RobertaModel

from ..base import Critic


class RoBERTaCritic(Critic):
    """
    RoBERTa Critic model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (RobertaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[RobertaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none',
                 **kwargs) -> None:
        if pretrained is not None:
            model = RobertaModel.from_pretrained(pretrained, add_pooling_layer=False)
        elif config is not None:
            model = RobertaModel(config)
        else:
            model = RobertaModel(RobertaConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
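
For orientation, a minimal shape sketch (an editorial illustration, not from the diff) of what the critic computes: the pooling-free RoBERTa backbone yields per-token hidden states, and value_head maps each to a scalar; the Critic base class (not shown here) reduces these to value estimates:

import torch
import torch.nn as nn
from transformers import RobertaConfig, RobertaModel

config = RobertaConfig()                                  # randomly initialized, no download
backbone = RobertaModel(config, add_pooling_layer=False)  # mirrors the critic's backbone
value_head = nn.Linear(config.hidden_size, 1)

ids = torch.randint(0, config.vocab_size, (1, 8))
hidden = backbone(input_ids=ids).last_hidden_state        # (1, 8, hidden_size)
values = value_head(hidden)                               # (1, 8, 1): one scalar per token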

coati/models/roberta/roberta_rm.py

@@ -0,0 +1,39 @@
from typing import Optional

import torch.nn as nn
from transformers import RobertaConfig, RobertaModel

from ..base import RewardModel


class RoBERTaRM(RewardModel):
    """
    RoBERTa Reward model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (RobertaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[RobertaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = RobertaModel.from_pretrained(pretrained, add_pooling_layer=False)
        elif config is not None:
            model = RobertaModel(config)
        else:
            model = RobertaModel(RobertaConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
        super().__init__(model, value_head, lora_rank, lora_train_bias)
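
One detail worth noting: the reward head is initialized with std = 1/(hidden_size + 1) rather than the nn.Linear default. A small sketch of the rationale (our reading; the commit does not state it): with roughly unit-scale hidden states, this keeps initial rewards near zero so early reward-model training is not dominated by a random head:

import torch
import torch.nn as nn

hidden_size = 768                                   # roberta-base
head = nn.Linear(hidden_size, 1)
head.weight.data.normal_(mean=0.0, std=1 / (hidden_size + 1))
h = torch.randn(1024, hidden_size)                  # stand-in for final hidden states
print(head(h).std())                                # about sqrt(768)/769, i.e. roughly 0.036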

inference.py

@@ -4,7 +4,8 @@ import torch
 from coati.models.bloom import BLOOMActor
 from coati.models.gpt import GPTActor
 from coati.models.opt import OPTActor
-from transformers import AutoTokenizer
+from coati.models.roberta import RoBERTaActor
+from transformers import AutoTokenizer, RobertaTokenizer
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

@@ -16,6 +17,8 @@ def eval(args):
         actor = BLOOMActor(pretrained=args.pretrain).to(torch.cuda.current_device())
     elif args.model == 'opt':
         actor = OPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
+    elif args.model == 'roberta':
+        actor = RoBERTaActor(pretrained=args.pretrain).to(torch.cuda.current_device())
     else:
         raise ValueError(f'Unsupported model "{args.model}"')

@@ -31,6 +34,8 @@ def eval(args):
         tokenizer.pad_token = tokenizer.eos_token
     elif args.model == 'opt':
         tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
+    elif args.model == 'roberta':
+        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
     else:
         raise ValueError(f'Unsupported model "{args.model}"')

@@ -49,7 +54,7 @@ def eval(args):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
+    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'roberta'])
     # We suggest to use the pretrained model from HuggingFace, use pretrain to configure model
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--model_path', type=str, default=None)

test_ci.sh

@@ -40,6 +40,13 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py \
     --save_path ${BASE}/actor_checkpoint_dummy.pt
 python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_dummy.pt --pretrain 'gpt2' --model gpt2
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py \
+    --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
+    --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
+    --pretrain 'roberta-base' --model roberta --lora_rank 4 \
+    --save_path ${BASE}/actor_checkpoint_dummy.pt
+python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_dummy.pt --pretrain 'roberta-base' --model roberta
 rm -rf ${BASE}/actor_checkpoint_dummy.pt

 # train prompts

@@ -68,6 +75,13 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \
     --save_path ${BASE}/actor_checkpoint_prompts.pt
 python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'gpt2' --model gpt2
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \
+    --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
+    --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
+    --pretrain 'roberta-base' --model roberta --lora_rank 4 \
+    --save_path ${BASE}/actor_checkpoint_prompts.pt
+python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'roberta-base' --model roberta
 rm -rf ${BASE}/actor_checkpoint_prompts.pt

 # train rm

@@ -94,4 +108,10 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
     --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
     --test True --lora_rank 4
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
+    --pretrain 'roberta-base' --model 'roberta' \
+    --strategy colossalai_zero2 --loss_fn 'log_exp' \
+    --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
+    --test True --lora_rank 4
 rm -rf ${BASE}/rm_ckpt.pt
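
The new reward-model test selects --loss_fn 'log_exp'. For readers unfamiliar with it, a hedged sketch of what that pairwise loss computes, assuming coati's LogExpLoss follows the usual Bradley-Terry form (the loss itself is not part of this diff):

import torch

def log_exp_loss(chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
    # -log sigmoid(r_chosen - r_reject), written as log(1 + exp(r_reject - r_chosen)):
    # the negative log-likelihood that the chosen reply outranks the rejected one.
    return torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()

print(log_exp_loss(torch.tensor([1.2, 0.3]), torch.tensor([0.1, 0.5])))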

train_dummy.py

@@ -6,11 +6,12 @@ from coati.models.base import RewardModel
 from coati.models.bloom import BLOOMActor, BLOOMCritic
 from coati.models.gpt import GPTActor, GPTCritic
 from coati.models.opt import OPTActor, OPTCritic
+from coati.models.roberta import RoBERTaActor, RoBERTaCritic
 from coati.trainer import PPOTrainer
 from coati.trainer.callbacks import SaveCheckpoint
 from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from torch.optim import Adam
-from transformers import AutoTokenizer, BloomTokenizerFast
+from transformers import AutoTokenizer, BloomTokenizerFast, RobertaTokenizer
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

 from colossalai.nn.optimizer import HybridAdam

@@ -46,6 +47,9 @@ def main(args):
     elif args.model == 'opt':
         actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+    elif args.model == 'roberta':
+        actor = RoBERTaActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+        critic = RoBERTaCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
     else:
         raise ValueError(f'Unsupported model "{args.model}"')

@@ -69,6 +73,8 @@ def main(args):
         tokenizer.pad_token = tokenizer.eos_token
     elif args.model == 'opt':
         tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+    elif args.model == 'roberta':
+        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
     else:
         raise ValueError(f'Unsupported model "{args.model}"')

@@ -128,7 +134,7 @@ if __name__ == '__main__':
     parser.add_argument('--strategy',
                         choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                         default='naive')
-    parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt'])
+    parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt', 'roberta'])
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy.pt')
     parser.add_argument('--need_optim_ckpt', type=bool, default=False)
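
Unlike the gpt2 branch above it, the new roberta tokenizer branch needs no pad_token = eos_token patch: RoBERTa's tokenizer ships with a real padding token. A small illustrative check (not from the diff):

from transformers import RobertaTokenizer

tok = RobertaTokenizer.from_pretrained('roberta-base')
print(tok.pad_token, tok.pad_token_id)    # '<pad>' 1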

train_prompts.py

@@ -8,13 +8,14 @@ from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
 from coati.models.gpt import GPTRM, GPTActor, GPTCritic
 from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
 from coati.models.opt import OPTRM, OPTActor, OPTCritic
+from coati.models.roberta import RoBERTaRM, RoBERTaActor, RoBERTaCritic
 from coati.trainer import PPOTrainer
 from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from coati.utils import prepare_llama_tokenizer_and_embedding
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
+from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer, RobertaTokenizer

 from colossalai.nn.optimizer import HybridAdam

@@ -44,6 +45,8 @@ def main(args):
         initial_model = OPTActor(pretrained=args.pretrain)
     elif args.model == 'llama':
         initial_model = LlamaActor(pretrained=args.pretrain)
+    elif args.model == 'roberta':
+        initial_model = RoBERTaActor(pretrained=args.pretrain)
     else:
         raise ValueError(f'Unsupported actor model "{args.model}"')

@@ -60,6 +63,8 @@ def main(args):
         reward_model = OPTRM(pretrained=args.rm_pretrain)
     elif rm_model_name == 'llama':
         reward_model = LlamaRM(pretrained=args.rm_pretrain)
+    elif rm_model_name == 'roberta':
+        reward_model = RoBERTaRM(pretrained=args.rm_pretrain)
     else:
         raise ValueError(f'Unsupported reward model "{rm_model_name}"')

@@ -79,6 +84,8 @@ def main(args):
         actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
     elif args.model == 'llama':
         actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
+    elif args.model == 'roberta':
+        actor = RoBERTaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
     else:
         raise ValueError(f'Unsupported actor model "{args.model}"')

@@ -90,6 +97,8 @@ def main(args):
         critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
     elif rm_model_name == 'llama':
         critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
+    elif rm_model_name == 'roberta':
+        critic = RoBERTaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
     else:
         raise ValueError(f'Unsupported reward model "{rm_model_name}"')

@@ -119,6 +128,8 @@ def main(args):
     elif args.model == 'llama':
         tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
         tokenizer.eos_token = '<\s>'
+    elif args.model == 'roberta':
+        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
     else:
         raise ValueError(f'Unsupported model "{args.model}"')

@@ -200,9 +211,9 @@ if __name__ == '__main__':
                         choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                         default='naive',
                         help='strategy to use')
-    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
+    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama', 'roberta'])
     parser.add_argument('--pretrain', type=str, default=None)
-    parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
+    parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama', 'roberta'])
    parser.add_argument('--rm_path', type=str, default=None)
     parser.add_argument('--rm_pretrain', type=str, default=None)
     parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
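
Because the critic and reward model branch on rm_model_name while the actor branches on args.model, a RoBERTa reward model can in principle judge a non-RoBERTa actor. A hedged sketch of that mixed wiring (argument values are hypothetical; strategy and trainer setup are omitted):

from coati.models.gpt import GPTActor
from coati.models.roberta import RoBERTaCritic, RoBERTaRM

# Hypothetical combination: --model gpt2 with --rm_model roberta.
actor = GPTActor(pretrained='gpt2', lora_rank=4)
critic = RoBERTaCritic(pretrained='roberta-base', lora_rank=4, use_action_mask=True)
reward_model = RoBERTaRM(pretrained='roberta-base')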

train_reward_model.py

@@ -11,12 +11,13 @@ from coati.models.deberta import DebertaRM
 from coati.models.gpt import GPTRM
 from coati.models.llama import LlamaRM
 from coati.models.opt import OPTRM
+from coati.models.roberta import RoBERTaRM
 from coati.trainer import RewardModelTrainer
 from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from coati.utils import prepare_llama_tokenizer_and_embedding
 from datasets import load_dataset
 from torch.optim import Adam
-from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer
+from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer, RobertaTokenizer
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

 from colossalai.nn.optimizer import HybridAdam

@@ -47,6 +48,8 @@ def train(args):
         model = DebertaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
     elif args.model == 'llama':
         model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+    elif args.model == 'roberta':
+        model = RoBERTaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
     else:
         raise ValueError(f'Unsupported model "{args.model}"')

@@ -67,6 +70,8 @@ def train(args):
         tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large')
     elif args.model == 'llama':
         tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
+    elif args.model == 'roberta':
+        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
     max_len = args.max_len

@@ -140,7 +145,7 @@ if __name__ == '__main__':
     parser.add_argument('--strategy',
                         choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                         default='naive')
-    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta', 'llama'], default='bloom')
+    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta', 'llama', 'roberta'], default='bloom')
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--model_path', type=str, default=None)
     parser.add_argument('--need_optim_ckpt', type=bool, default=False)
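
Finally, a hedged sketch of scoring responses with the trained reward model. The call signature is assumed from coati's RewardModel base at the time (sequences plus attention_mask); it is not shown in this diff, so verify against the base class before relying on it:

import torch
from transformers import RobertaTokenizer
from coati.models.roberta import RoBERTaRM

tok = RobertaTokenizer.from_pretrained('roberta-base')
rm = RoBERTaRM(pretrained='roberta-base').eval()

batch = tok(['A helpful, harmless answer.', 'A rude answer.'],
            return_tensors='pt', padding=True)
with torch.no_grad():
    # Assumed signature: one scalar reward per sequence in the batch.
    rewards = rm(batch['input_ids'], attention_mask=batch['attention_mask'])
print(rewards)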