diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index 1201571..75adab6 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -11,7 +11,6 @@ from enum import Enum
 from typing import Callable, Dict, Union
 
 import torch
-from torch.distributed.fsdp import FullStateDictConfig, LocalStateDictConfig
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import StateDictType
 
@@ -177,6 +176,7 @@ def get_shard_state_dict(shard_model):
 
     return shard_states
 
+
 def load_shard_state_dict(shard_model, shard_state, **kwargs):
     """
     Only used for FSDP module loading.
@@ -329,10 +329,9 @@ def load_model_checkpoint(folder, model):
 
     # avoid ckpt misuse between FSDP and no-FSDP
     test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop()
-    assert (
-        "_zo" in test_fn and gpc.config.parallel.use_fsdp or
+    assert ("_zo" in test_fn and gpc.config.parallel.use_fsdp) or (
         "_zo" not in test_fn and not gpc.config.parallel.use_fsdp
-    ), f"FSDP model wants to load no-FSDP ckpts or reverse"
+    ), "FSDP model wants to load no-FSDP ckpts or reverse"
 
     max_pp, max_tp, max_zo = 0, 0, 0
     for fn in fns:
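
Note on the assertion hunk: Python's `and` binds tighter than `or`, so the added parentheses do not change the condition's value; they only make the intended pairing explicit. The `f` prefix is dropped because the message contains no placeholders. A minimal standalone sketch of the equivalent check, assuming a plain `use_fsdp` boolean in place of `gpc.config.parallel.use_fsdp`, with the "_zo" filename tag standing for FSDP-sharded checkpoints (helper name and sample filenames are hypothetical):

    # Hypothetical helper illustrating the checkpoint/parallel-mode check;
    # `use_fsdp` stands in for gpc.config.parallel.use_fsdp.
    def ckpt_matches_mode(test_fn: str, use_fsdp: bool) -> bool:
        # True only when the "_zo" shard tag and the FSDP flag agree
        # (both present or both absent), i.e. an XNOR of the two.
        return ("_zo" in test_fn and use_fsdp) or ("_zo" not in test_fn and not use_fsdp)

    assert ckpt_matches_mode("model_tp0_pp0_zo0.pt", use_fsdp=True)
    assert ckpt_matches_mode("model_tp0_pp0.pt", use_fsdp=False)
    assert not ckpt_matches_mode("model_tp0_pp0.pt", use_fsdp=True)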