diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index d2d61b1..276dcad 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -173,6 +173,10 @@ def args_sanity_check(): logger.info("+" * 15 + " beta2_scheduler Info " + "+" * 15) # pylint: disable=W1201 logger.info(f"beta2_scheduler: {gpc.config.beta2_scheduler}") + # process the model config + if "use_flash_attn" not in gpc.config.model: + gpc.config.model._add_item("use_flash_attn", True) + def launch( config: Union[str, Path, Config, Dict],