fix the CI incompatibility in config

pull/407/head
yingtongxiong 2023-10-09 21:33:26 +08:00
parent 007e58a4af
commit a8dea6313f
1 changed file with 4 additions and 1 deletion

View File

@ -305,9 +305,12 @@ def args_sanity_check():
gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
), "sequence parallel does not support use_flash_attn=False"
if isinstance (gpc.config.parallel["tensor"], int):
gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode='origin_tp')
if gpc.config.parallel["tensor"].get("mode", None) is None:
gpc.config.parallel["tensor"]["mode"] = "origin_tp"
if gpc.config.parallel["tensor"].get("mode", None) == "fstp":
assert (
gpc.config.parallel.sequence_parallel is True