diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 64577c3..0c00bfd 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -74,15 +74,14 @@ def args_sanity_check():
         gpc.config.parallel._add_item("tensor", 1)
 
     if isinstance(gpc.config.parallel.pipeline, int):
-        pp = gpc.config.parallel.pipelines
+        pp = gpc.config.parallel.pipeline
     else:
         pp = gpc.config.parallel.pipeline.size
 
     if "use_fsdp" not in gpc.config.parallel:
         gpc.config.parallel._add_item("use_fsdp", False)
-    elif gpc.config.parallel.use_fsdp and pp > 1:
-        logger.warning("FSDP not support when pipeline parallel is enabled, auto-close FSDP")
-        gpc.config.parallel._add_item("use_fsdp", False)
+
+    assert not (gpc.config.parallel.use_fsdp and pp > 1), "FSDP does not support pipeline size > 1, please set pipeline size to 1 or disable FSDP"
 
     # processing the data config in gpc
     data = gpc.config.data
@@ -282,6 +281,9 @@ def args_sanity_check():
             model._add_item("moe_use_residual", False)
         if "moe_gate_k" not in model:
             model._add_item("moe_gate_k", 2)
+    assert not (
+        gpc.config.model.num_experts > 1 and gpc.config.parallel.use_fsdp
+    ), "FSDP does not support num_experts > 1"
 
     # process the parallel config
     if "sequence_parallel" not in gpc.config.parallel:
diff --git a/train.py b/train.py
index 4c04818..71ce548 100644
--- a/train.py
+++ b/train.py
@@ -301,14 +301,15 @@ def main(args):
 
 
 if __name__ == "__main__":
-    assert torch.__version__ >= "2.0.1", f"requires torch>=2.0.1 but current version is {torch.__version__}"
-
     args = parse_args()
     hostname = socket.gethostname()
 
     # initialize distributed environment
     initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
     assert hasattr(gpc, "config") and gpc.config is not None
 
+    if gpc.config.parallel.use_fsdp:
+        assert torch.__version__ >= "2.0.1", f"requires torch>=2.0.1 but current version is {torch.__version__}"
+
     # initialize monitor manager context
     with initialize_monitor_manager(
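
For context, a minimal standalone sketch of the compatibility rules the new assertions enforce. The dict-based config and the `check_fsdp_compat` helper below are hypothetical stand-ins for illustration only; the real checks live in `args_sanity_check()` and operate on `gpc.config`.

```python
def check_fsdp_compat(parallel: dict, model: dict) -> None:
    """Mirror the FSDP assertions added to args_sanity_check() (sketch only)."""
    pipeline = parallel.get("pipeline", 1)
    # Pipeline config may be a plain int or a dict with a "size" field.
    pp = pipeline if isinstance(pipeline, int) else pipeline["size"]
    use_fsdp = parallel.get("use_fsdp", False)

    # FSDP cannot be combined with pipeline parallelism.
    assert not (use_fsdp and pp > 1), (
        "FSDP does not support pipeline size > 1, "
        "please set pipeline size to 1 or disable FSDP"
    )
    # FSDP cannot be combined with MoE (num_experts > 1).
    assert not (use_fsdp and model.get("num_experts", 1) > 1), (
        "FSDP does not support num_experts > 1"
    )


if __name__ == "__main__":
    # Valid: FSDP with pipeline size 1 and a single expert.
    check_fsdp_compat({"pipeline": 1, "use_fsdp": True}, {"num_experts": 1})
    # Invalid: FSDP combined with pipeline parallelism is rejected at startup.
    try:
        check_fsdp_compat({"pipeline": 2, "use_fsdp": True}, {"num_experts": 1})
    except AssertionError as exc:
        print(f"rejected as expected: {exc}")
```

Failing fast with an assertion, instead of silently disabling FSDP as the removed `logger.warning` path did, surfaces the misconfiguration at startup rather than changing the parallel layout behind the user's back; likewise, the torch>=2.0.1 requirement now applies only when FSDP is actually enabled.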