mirror of https://github.com/InternLM/InternLM
fix load ckpt bug2
parent 5b62a3957a
commit 6b7ca1c6b3
@@ -11,7 +11,6 @@ from enum import Enum
 from typing import Callable, Dict, Union

 import torch
 from torch.distributed.fsdp import FullStateDictConfig, LocalStateDictConfig
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import StateDictType
@@ -177,6 +176,7 @@ def get_shard_state_dict(shard_model):
     return shard_states


 def load_shard_state_dict(shard_model, shard_state, **kwargs):
     """
     Only used for FSDP module loading.
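For context, `get_shard_state_dict` / `load_shard_state_dict` handle FSDP-sharded checkpoints. The sketch below is not the repository's implementation; it only illustrates, using the same imports the first hunk keeps (`FSDP`, `StateDictType`, `LocalStateDictConfig`), how a per-rank sharded state dict is typically saved and reloaded. The `_sketch` function names are invented for illustration.

```python
# Minimal sketch (an assumption, not InternLM's code): each rank saves and
# reloads only its own shard via FSDP's LOCAL_STATE_DICT mode.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import LocalStateDictConfig, StateDictType


def get_shard_state_dict_sketch(shard_model: FSDP) -> dict:
    # LOCAL_STATE_DICT yields per-rank shards instead of a fully gathered dict.
    with FSDP.state_dict_type(
        shard_model, StateDictType.LOCAL_STATE_DICT, LocalStateDictConfig(offload_to_cpu=True)
    ):
        return shard_model.state_dict()


def load_shard_state_dict_sketch(shard_model: FSDP, shard_state: dict, **kwargs) -> None:
    # Only valid when the checkpoint was written with the same world size and
    # sharding layout; otherwise the shards will not line up.
    with FSDP.state_dict_type(
        shard_model, StateDictType.LOCAL_STATE_DICT, LocalStateDictConfig(offload_to_cpu=True)
    ):
        shard_model.load_state_dict(shard_state, **kwargs)
```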
@@ -329,10 +329,9 @@ def load_model_checkpoint(folder, model):

     # avoid ckpt misuse between FSDP and no-FSDP
     test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop()
-    assert (
-        "_zo" in test_fn and gpc.config.parallel.use_fsdp or
+    assert ("_zo" in test_fn and gpc.config.parallel.use_fsdp) or (
         "_zo" not in test_fn and not gpc.config.parallel.use_fsdp
-    ), f"FSDP model wants to load no-FSDP ckpts or reverse"
+    ), "FSDP model wants to load no-FSDP ckpts or reverse"

     max_pp, max_tp, max_zo = 0, 0, 0
     for fn in fns:
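The hunk above adds explicit parentheses and drops a redundant `f`-string prefix; since `and` already binds tighter than `or` in Python, the condition's meaning is unchanged and the rewrite is for readability. Below is a hedged, self-contained sketch of the check being performed; the helper name and example filenames are invented for illustration, and the repository derives `fns` and `gpc.config.parallel.use_fsdp` elsewhere.

```python
# Illustrative only: checkpoints written by the FSDP path carry a "_zo" shard
# tag in their filenames, so that tag must agree with the use_fsdp setting.
def check_ckpt_matches_fsdp(fns: list, use_fsdp: bool) -> None:
    test_fn = [f for f in fns if f.startswith("model_t") and not f.endswith(".md5")].pop()
    is_fsdp_ckpt = "_zo" in test_fn
    assert (is_fsdp_ckpt and use_fsdp) or (not is_fsdp_ckpt and not use_fsdp), (
        "FSDP model wants to load no-FSDP ckpts or reverse"
    )


check_ckpt_matches_fsdp(["model_tp0_pp0_zo0.pt"], use_fsdp=True)   # FSDP ckpt + FSDP run: ok
check_ckpt_matches_fsdp(["model_tp0_pp0.pt"], use_fsdp=False)      # plain ckpt + plain run: ok
```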