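"""Command-line arguments for a masked-LM (BERT/DeBERTa) pretraining script built on ColossalAI."""
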
import colossalai

__all__ = ['parse_args']


def parse_args():
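    """Build the argument parser and parse the command-line arguments.

    Extends ColossalAI's default parser with the pretraining-specific options
    defined below and returns the parsed ``argparse.Namespace``.
    """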
    parser = colossalai.get_default_parser()

    parser.add_argument(
        '--lr',
        type=float,
        required=True,
        help='initial learning rate')
    parser.add_argument(
        '--epoch',
        type=int,
        required=True,
        help='number of epochs')
    parser.add_argument(
        '--data_path_prefix',
        type=str,
        required=True,
        help="location of the train data corpus")
    parser.add_argument(
        '--eval_data_path_prefix',
        type=str,
        required=True,
        help='location of the evaluation data corpus')
    parser.add_argument(
        '--tokenizer_path',
        type=str,
        required=True,
        help='location of the tokenizer')
    parser.add_argument(
        '--max_seq_length',
        type=int,
        default=512,
        help='sequence length')
    parser.add_argument(
        '--refresh_bucket_size',
        type=int,
        default=1,
        help="number of times a task is repeated to optimise back-propagation speed with APEX's DistributedDataParallel")
    parser.add_argument(
        "--max_predictions_per_seq",
        "--max_pred",
        default=80,
        type=int,
        help="the maximum number of masked tokens in a sequence to be predicted")
    parser.add_argument(
        "--gradient_accumulation_steps",
        default=1,
        type=int,
        help="number of gradient accumulation steps")
    parser.add_argument(
        "--train_micro_batch_size_per_gpu",
        default=2,
        type=int,
        required=True,
        help="train micro batch size per GPU")
    parser.add_argument(
        "--eval_micro_batch_size_per_gpu",
        default=2,
        type=int,
        required=True,
        help="eval micro batch size per GPU")
    parser.add_argument(
        "--num_workers",
        default=8,
        type=int,
        help="number of dataloader workers")
    parser.add_argument(
        "--async_worker",
        action='store_true',
        help="use asynchronous dataloader workers")
    parser.add_argument(
        "--bert_config",
        required=True,
        type=str,
        help="location of config.json")
    parser.add_argument(
        "--wandb",
        action='store_true',
        help="use wandb to watch the model")
    parser.add_argument(
        "--wandb_project_name",
        default='roberta',
        help="wandb project name")
    parser.add_argument(
        "--log_interval",
        default=100,
        type=int,
        help="report interval")
    parser.add_argument(
        "--log_path",
        type=str,
        required=True,
        help="log file which records train steps")
    parser.add_argument(
        "--tensorboard_path",
        type=str,
        required=True,
        help="location of the tensorboard files")
    parser.add_argument(
        "--colossal_config",
        type=str,
        required=True,
        help="colossal config file, which contains the zero config among other settings")
    parser.add_argument(
        "--ckpt_path",
        type=str,
        required=True,
        help="location for saving checkpoints, which contain the model and optimizer")
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help="random seed for initialization")
    parser.add_argument(
        '--vscode_debug',
        action='store_true',
        help="use vscode to debug")
    parser.add_argument(
        '--load_pretrain_model',
        default='',
        type=str,
        help="location of the model checkpoint")
    parser.add_argument(
        '--load_optimizer_lr',
        default='',
        type=str,
        help="location of the checkpoint, which contains the optimizer, learning rate, epoch, shard and global_step")
    parser.add_argument(
        '--resume_train',
        action='store_true',
        help="whether to resume training from an earlier checkpoint")
    parser.add_argument(
        '--mlm',
        default='bert',
        type=str,
        help="model type, bert or deberta")
    parser.add_argument(
        '--checkpoint_activations',
        action='store_true',
        help="whether to use gradient checkpointing")

    args = parser.parse_args()
    return args
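

# Minimal usage sketch (illustrative; the script name, launch command, and
# paths below are assumptions, not part of this file). A training script
# would call ``args = parse_args()`` and could be launched, for example, as:
#
#   torchrun --nproc_per_node=8 pretrain.py \
#       --lr 1e-4 --epoch 1 \
#       --data_path_prefix /path/to/train_data \
#       --eval_data_path_prefix /path/to/eval_data \
#       --tokenizer_path /path/to/tokenizer \
#       --bert_config /path/to/config.json \
#       --train_micro_batch_size_per_gpu 2 \
#       --eval_micro_batch_size_per_gpu 2 \
#       --log_path ./train.log \
#       --tensorboard_path ./tensorboard \
#       --colossal_config /path/to/colossal_config.py \
#       --ckpt_path ./checkpoints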