mirror of https://github.com/InternLM/InternLM
modify args_sanity_check for fsdp with pipeline and fsdp with moe
parent eb14dae005
commit 132a841d42
@@ -74,15 +74,14 @@ def args_sanity_check():
         gpc.config.parallel._add_item("tensor", 1)
 
     if isinstance(gpc.config.parallel.pipeline, int):
-        pp = gpc.config.parallel.pipelines
+        pp = gpc.config.parallel.pipeline
     else:
         pp = gpc.config.parallel.pipeline.size
 
     if "use_fsdp" not in gpc.config.parallel:
         gpc.config.parallel._add_item("use_fsdp", False)
-    elif gpc.config.parallel.use_fsdp and pp > 1:
-        logger.warning("FSDP not support when pipeline parallel is enabled, auto-close FSDP")
-        gpc.config.parallel._add_item("use_fsdp", False)
+    assert not (gpc.config.parallel.use_fsdp and pp > 1), "FSDP not support when pipeline size > 1, please set pipeline size to 1 or close FSDP"
 
     # processing the data config in gpc
     data = gpc.config.data
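With this hunk, a config that enables both FSDP and pipeline parallelism now fails fast in args_sanity_check instead of logging a warning and silently forcing use_fsdp back to False. A minimal sketch of a config that the new assert rejects, assuming the usual dict-style Python training config; the zero1 and tensor values are illustrative, only pipeline and use_fsdp matter for this check:

    # illustrative training-config snippet (not part of this commit)
    parallel = dict(
        zero1=8,         # illustrative value
        tensor=1,        # illustrative value
        pipeline=2,      # pipeline size > 1 ...
        use_fsdp=True,   # ... combined with FSDP now raises at startup:
        # AssertionError: FSDP not support when pipeline size > 1,
        # please set pipeline size to 1 or close FSDP
    )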
@@ -282,6 +281,9 @@ def args_sanity_check():
         model._add_item("moe_use_residual", False)
     if "moe_gate_k" not in model:
         model._add_item("moe_gate_k", 2)
+    assert not (
+        gpc.config.model.num_experts > 1 and gpc.config.parallel.use_fsdp
+    ), "FSDP does not support num_experts > 1"
 
     # process the parallel config
     if "sequence_parallel" not in gpc.config.parallel:
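The second hunk adds the matching guard for MoE: once the MoE defaults (moe_use_residual, moe_gate_k) are filled in, a model with more than one expert can no longer be combined with FSDP. A hedged sketch of a config that now trips the assertion; the dict layout and any other fields are assumptions, only num_experts and use_fsdp come from the diff:

    # illustrative training-config snippet (not part of this commit)
    model = dict(
        num_experts=4,   # MoE enabled: more than one expert
    )
    parallel = dict(
        use_fsdp=True,   # FSDP enabled
    )
    # args_sanity_check() now fails with:
    # AssertionError: FSDP does not support num_experts > 1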
train.py (5 changed lines)

@@ -301,14 +301,15 @@ def main(args):
 
 
 if __name__ == "__main__":
-    assert torch.__version__ >= "2.0.1", f"requires torch>=2.0.1 but current version is {torch.__version__}"
 
     args = parse_args()
     hostname = socket.gethostname()
 
     # initialize distributed environment
     initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
     assert hasattr(gpc, "config") and gpc.config is not None
+    if gpc.config.parallel.use_fsdp:
+        assert torch.__version__ >= "2.0.1", f"requires torch>=2.0.1 but current version is {torch.__version__}"
 
     # initialize monitor manager context
     with initialize_monitor_manager(
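In train.py the torch version requirement is narrowed: instead of demanding torch >= 2.0.1 for every run, the check now happens after the config is loaded and only when use_fsdp is enabled, since only the FSDP path needs the newer PyTorch. A small sketch of the same guard, rewritten with packaging.version instead of the raw string comparison used in the diff (the helper name is made up for illustration; plain string comparison can misorder versions with multi-digit components or pre-release tags):

    from packaging import version  # assumption: not used by the commit itself
    import torch

    def check_fsdp_requirements(use_fsdp: bool) -> None:
        # Hypothetical helper: only FSDP runs require torch >= 2.0.1.
        if use_fsdp:
            assert version.parse(torch.__version__) >= version.parse("2.0.1"), (
                f"requires torch>=2.0.1 but current version is {torch.__version__}"
            )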