
add fp8_communication flag to the finetune example scripts

Branch: pull/5885/head
Author: BurkeHulk
Commit: 66018749f3
  1. examples/language/bert/finetune.py (2 additions)
  2. examples/language/gpt/hybridparallelism/finetune.py (2 additions)

examples/language/bert/finetune.py (2 additions)

@@ -190,6 +190,7 @@ def main():
 )
 parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
 parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
+parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")
 args = parser.parse_args()
 if args.model_type == "bert":
@@ -232,6 +233,7 @@ def main():
 zero_stage=1,
 precision="fp16",
 initial_scale=1,
+fp8_communication=args.use_fp8_comm,
 )
 booster = Booster(plugin=plugin, **booster_kwargs)

examples/language/gpt/hybridparallelism/finetune.py (2 additions)

@@ -187,6 +187,7 @@ def main():
 )
 parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
 parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
+parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")
 args = parser.parse_args()
 if args.model_type == "gpt2":
@@ -225,6 +226,7 @@ def main():
 zero_stage=1,
 precision="fp16",
 initial_scale=1,
+fp8_communication=args.use_fp8_comm,
 )
 booster = Booster(plugin=plugin, **booster_kwargs)
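
Both files receive the same two-line change: a --use_fp8_comm command-line switch and an fp8_communication keyword forwarded to the booster plugin. Below is a minimal sketch of that wiring; the HybridParallelPlugin name is inferred from the hybridparallelism example path and, together with the commented-out constructor call, should be read as illustrative rather than a reproduction of the scripts.

# Minimal sketch (not from the commit) of the pattern both hunks add: parse a
# --use_fp8_comm flag and forward it to the plugin as fp8_communication.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")
args = parser.parse_args()

plugin_kwargs = dict(
    zero_stage=1,
    precision="fp16",
    initial_scale=1,
    fp8_communication=args.use_fp8_comm,  # new flag, passed through unchanged
)
# Constructing the plugin requires an initialized distributed environment, so it
# is only sketched here; HybridParallelPlugin is assumed from the example path.
# plugin = HybridParallelPlugin(tp_size=1, pp_size=1, **plugin_kwargs)
# booster = Booster(plugin=plugin, **booster_kwargs)
print(plugin_kwargs)

Note that, like the existing --use_lazy_init option, the new flag uses type=bool, so argparse applies Python's bool() to the raw string: any non-empty value (including the literal "False") enables FP8 communication, while omitting the flag or passing an empty string leaves it disabled.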
