mirror of https://github.com/hpcaitech/ColossalAI
add fp8_communication flag in the script
parent
e88190184a
commit
66018749f3
|
@ -190,6 +190,7 @@ def main():
|
|||
)
|
||||
parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
|
||||
parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
|
||||
parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model_type == "bert":
|
||||
|
@ -232,6 +233,7 @@ def main():
|
|||
zero_stage=1,
|
||||
precision="fp16",
|
||||
initial_scale=1,
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
|
|
|
@ -187,6 +187,7 @@ def main():
|
|||
)
|
||||
parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
|
||||
parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
|
||||
parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model_type == "gpt2":
|
||||
|
@ -225,6 +226,7 @@ def main():
|
|||
zero_stage=1,
|
||||
precision="fp16",
|
||||
initial_scale=1,
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
|
|
Loading…
Reference in New Issue