Merge branch 'feature_add_fsdp3' of https://github.com/zaglc/InternLM into feature_add_fsdp3

pull/293/head
huangting4201 2023-10-08 17:16:06 +08:00
commit 610e011133
2 changed files with 9 additions and 6 deletions

@@ -74,15 +74,14 @@ def args_sanity_check():
gpc.config.parallel._add_item("tensor", 1)
if isinstance(gpc.config.parallel.pipeline, int):
pp = gpc.config.parallel.pipelines
pp = gpc.config.parallel.pipeline
else:
pp = gpc.config.parallel.pipeline.size
if "use_fsdp" not in gpc.config.parallel:
gpc.config.parallel._add_item("use_fsdp", False)
elif gpc.config.parallel.use_fsdp and pp > 1:
logger.warning("FSDP not support when pipeline parallel is enabled, auto-close FSDP")
gpc.config.parallel._add_item("use_fsdp", False)
assert not (gpc.config.parallel.use_fsdp and pp > 1), "FSDP not support when pipeline size > 1, please set pipeline size to 1 or close FSDP"
# processing the data config in gpc
data = gpc.config.data
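Note on the hunk above: it reads the pipeline parallel size from either form of the config entry (a plain int, or a dict carrying a `size` field) and then rejects the combination of FSDP and pipeline parallelism. A minimal sketch of the two config shapes involved; the values below are illustrative, not taken from this commit:

# Illustrative config sketch (values are examples, not from this commit):
parallel = dict(
    tensor=1,
    pipeline=1,               # plain int: pipeline parallel size, read directly
    # pipeline=dict(size=1),  # dict form: size read via gpc.config.parallel.pipeline.size
    use_fsdp=True,            # only valid together with pipeline size 1
)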
@@ -282,6 +281,9 @@ def args_sanity_check():
model._add_item("moe_use_residual", False)
if "moe_gate_k" not in model:
model._add_item("moe_gate_k", 2)
assert not (
gpc.config.model.num_experts > 1 and gpc.config.parallel.use_fsdp
), "FSDP does not support num_experts > 1"
# process the parallel config
if "sequence_parallel" not in gpc.config.parallel:

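Note on the hunk above: the new assert ties the model config to the parallel config, so that a MoE model (`num_experts > 1`) cannot be trained with FSDP enabled. A small sketch of the combinations this check accepts and rejects; the dicts are illustrative only:

# Illustrative only: combinations the new assert accepts or rejects.
parallel = dict(use_fsdp=True)
model = dict(num_experts=1)    # accepted: FSDP with a non-MoE model
# model = dict(num_experts=4)  # rejected: raises "FSDP does not support num_experts > 1"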

@@ -301,14 +301,15 @@ def main(args):
if __name__ == "__main__":
assert torch.__version__ >= "2.0.1", f"requires torch>=2.0.1 but current version is {torch.__version__}"
args = parse_args()
hostname = socket.gethostname()
# initialize distributed environment
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
assert hasattr(gpc, "config") and gpc.config is not None
if gpc.config.parallel.use_fsdp:
assert torch.__version__ >= "2.0.1", f"requires torch>=2.0.1 but current version is {torch.__version__}"
# initialize monitor manager context
with initialize_monitor_manager(
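Side note on the torch version check above (which this commit appears to gate behind `use_fsdp`): comparing `torch.__version__` as a raw string happens to work for 2.0.1 against current 2.x releases, but string ordering mis-ranks versions in general (e.g. "1.13.0" < "1.9.0" as strings). A sketch of a parsed comparison, not part of the diff:

# Sketch only, not part of this commit: a parsed comparison avoids raw string
# ordering pitfalls and tolerates local build suffixes such as "2.0.1+cu117".
import torch
from packaging import version

assert version.parse(torch.__version__) >= version.parse("2.0.1"), (
    f"requires torch>=2.0.1 but current version is {torch.__version__}"
)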