diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 6a094e7..e5bd861 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -305,9 +305,12 @@ def args_sanity_check(): gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False ), "sequence parallel does not support use_flash_attn=False" + if isinstance (gpc.config.parallel["tensor"], int): + gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode='origin_tp') + if gpc.config.parallel["tensor"].get("mode", None) is None: gpc.config.parallel["tensor"]["mode"] = "origin_tp" - + if gpc.config.parallel["tensor"].get("mode", None) == "fstp": assert ( gpc.config.parallel.sequence_parallel is True