mirror of https://github.com/InternLM/InternLM
Merge branch 'feature_add_fsdp3' of https://github.com/zaglc/InternLM into feature_add_fsdp3
commit 610e011133
@@ -74,15 +74,14 @@ def args_sanity_check():
         gpc.config.parallel._add_item("tensor", 1)

     if isinstance(gpc.config.parallel.pipeline, int):
-        pp = gpc.config.parallel.pipelines
+        pp = gpc.config.parallel.pipeline
     else:
         pp = gpc.config.parallel.pipeline.size

     if "use_fsdp" not in gpc.config.parallel:
         gpc.config.parallel._add_item("use_fsdp", False)
-    elif gpc.config.parallel.use_fsdp and pp > 1:
-        logger.warning("FSDP not support when pipeline parallel is enabled, auto-close FSDP")
-        gpc.config.parallel._add_item("use_fsdp", False)
+
+    assert not (gpc.config.parallel.use_fsdp and pp > 1), "FSDP not support when pipeline size > 1, please set pipeline size to 1 or close FSDP"

     # processing the data config in gpc
     data = gpc.config.data
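In short, this hunk fixes the `pipelines` typo and replaces the old warn-and-auto-disable path with a hard failure: a run that enables FSDP together with pipeline parallelism now aborts instead of silently switching FSDP off. A minimal standalone sketch of the new check, with plain dicts standing in for gpc.config (field names are taken from the diff, values are illustrative):

    # Sketch of the new sanity check in isolation.
    parallel = dict(tensor=1, pipeline=dict(size=1), use_fsdp=True)

    pipeline = parallel["pipeline"]
    pp = pipeline if isinstance(pipeline, int) else pipeline["size"]

    # FSDP and pipeline parallelism are mutually exclusive: pp must be 1.
    assert not (parallel["use_fsdp"] and pp > 1), (
        "FSDP not support when pipeline size > 1, "
        "please set pipeline size to 1 or close FSDP"
    )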
@@ -282,6 +281,9 @@ def args_sanity_check():
         model._add_item("moe_use_residual", False)
     if "moe_gate_k" not in model:
         model._add_item("moe_gate_k", 2)
+    assert not (
+        gpc.config.model.num_experts > 1 and gpc.config.parallel.use_fsdp
+    ), "FSDP does not support num_experts > 1"

     # process the parallel config
     if "sequence_parallel" not in gpc.config.parallel:
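This second hunk adds the matching guard for MoE models: a config with num_experts > 1 cannot also enable FSDP. A standalone sketch of the constraint, again with plain dicts standing in for gpc.config (field names from the diff, values illustrative):

    # Sketch of the MoE/FSDP guard in isolation.
    model = dict(num_experts=4, moe_gate_k=2, moe_use_residual=False)
    parallel = dict(use_fsdp=False)  # must stay False once num_experts > 1

    assert not (model["num_experts"] > 1 and parallel["use_fsdp"]), (
        "FSDP does not support num_experts > 1"
    )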
train.py (5 changes)
@@ -301,14 +301,15 @@ def main(args):


 if __name__ == "__main__":
-    assert torch.__version__ >= "2.0.1", f"requires torch>=2.0.1 but current version is {torch.__version__}"
-
     args = parse_args()
     hostname = socket.gethostname()

     # initialize distributed environment
     initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
+    assert hasattr(gpc, "config") and gpc.config is not None
+    if gpc.config.parallel.use_fsdp:
+        assert torch.__version__ >= "2.0.1", f"requires torch>=2.0.1 but current version is {torch.__version__}"


     # initialize monitor manager context
     with initialize_monitor_manager(
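The train.py change moves the torch version check behind the FSDP flag: the torch>=2.0.1 requirement now applies only to FSDP runs, and only after initialize_distributed_env has loaded the config into gpc. A sketch of the pattern (the helper name is hypothetical, not from the repo):

    import torch

    def check_fsdp_torch_version(use_fsdp: bool) -> None:
        """Hypothetical helper mirroring the reordered check in train.py:
        FSDP relies on torch 2.x APIs, so the version assert is gated on
        the config flag instead of failing every run at startup."""
        if use_fsdp:
            # Same lexicographic version comparison as the source line.
            assert torch.__version__ >= "2.0.1", (
                f"requires torch>=2.0.1 but current version is {torch.__version__}"
            )

    check_fsdp_torch_version(use_fsdp=True)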