|
|
|
@ -190,6 +190,7 @@ def main():
|
|
|
|
|
) |
|
|
|
|
parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached") |
|
|
|
|
parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context") |
|
|
|
|
parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication") |
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
|
if args.model_type == "bert": |
|
|
|
@ -232,6 +233,7 @@ def main():
|
|
|
|
|
zero_stage=1, |
|
|
|
|
precision="fp16", |
|
|
|
|
initial_scale=1, |
|
|
|
|
fp8_communication=args.use_fp8_comm, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
booster = Booster(plugin=plugin, **booster_kwargs) |
|
|
|
|