add fp8_communication flag in the script

pull/5885/head
BurkeHulk 2024-07-12 15:26:17 +08:00
parent e88190184a
commit 66018749f3
2 changed files with 4 additions and 0 deletions

View File

@ -190,6 +190,7 @@ def main():
)
parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")
args = parser.parse_args()
if args.model_type == "bert":
@ -232,6 +233,7 @@ def main():
zero_stage=1,
precision="fp16",
initial_scale=1,
fp8_communication=args.use_fp8_comm,
)
booster = Booster(plugin=plugin, **booster_kwargs)

View File

@ -187,6 +187,7 @@ def main():
)
parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")
args = parser.parse_args()
if args.model_type == "gpt2":
@ -225,6 +226,7 @@ def main():
zero_stage=1,
precision="fp16",
initial_scale=1,
fp8_communication=args.use_fp8_comm,
)
booster = Booster(plugin=plugin, **booster_kwargs)