mirror of https://github.com/InternLM/InternLM
fix load ckpt bug2
parent 5b62a3957a
commit 6b7ca1c6b3
@@ -11,7 +11,6 @@ from enum import Enum
 from typing import Callable, Dict, Union
 
 import torch
-from torch.distributed.fsdp import FullStateDictConfig, LocalStateDictConfig
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import StateDictType
 
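For context on the imports touched above: `StateDictType` selects which checkpoint format an FSDP-wrapped module emits, while the config classes (`FullStateDictConfig`, `LocalStateDictConfig`) only apply when the matching mode is chosen. Below is a minimal, generic PyTorch sketch of that mechanism, not code from this repository; the model, its dimensions, and the assumption that a process group is already initialized are illustrative only.

```python
import torch.nn as nn
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed has been initialized and each rank owns a GPU.
model = FSDP(nn.Linear(1024, 1024).cuda())

# FULL_STATE_DICT gathers the unsharded weights, here offloaded to CPU on rank 0 only.
full_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_cfg):
    full_sd = model.state_dict()

# LOCAL_STATE_DICT instead returns only the flat shard owned by the current rank,
# which appears to be the style of per-rank checkpoint that the FSDP ("_zo")
# files checked later in this module correspond to.
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
    shard_sd = model.state_dict()
```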
@@ -177,6 +176,7 @@ def get_shard_state_dict(shard_model):
     return shard_states
 
 
 def load_shard_state_dict(shard_model, shard_state, **kwargs):
     """
     Only used for FSDP module loading.
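The two helpers named in this hunk wrap that machinery for the FSDP path. Their bodies are not visible in the diff, so the following is only a hedged sketch of what such save/load helpers conventionally look like with `LOCAL_STATE_DICT` sharding; the function names match the diff, the implementation is assumed.

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


def get_shard_state_dict(shard_model):
    # Each rank extracts only the flat parameter shards it owns.
    with FSDP.state_dict_type(shard_model, StateDictType.LOCAL_STATE_DICT):
        shard_states = shard_model.state_dict()
    return shard_states


def load_shard_state_dict(shard_model, shard_state, **kwargs):
    """
    Only used for FSDP module loading.

    The checkpoint must have been produced with the same world size and
    sharding layout, otherwise the local shards will not line up.
    """
    with FSDP.state_dict_type(shard_model, StateDictType.LOCAL_STATE_DICT):
        return shard_model.load_state_dict(shard_state, **kwargs)
```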
@@ -329,10 +329,9 @@ def load_model_checkpoint(folder, model):
 
     # avoid ckpt misuse between FSDP and no-FSDP
     test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop()
-    assert (
-        "_zo" in test_fn and gpc.config.parallel.use_fsdp or
-        "_zo" not in test_fn and not gpc.config.parallel.use_fsdp
-    ), f"FSDP model wants to load no-FSDP ckpts or reverse"
+    assert ("_zo" in test_fn and gpc.config.parallel.use_fsdp) or (
+        "_zo" not in test_fn and not gpc.config.parallel.use_fsdp
+    ), "FSDP model wants to load no-FSDP ckpts or reverse"
 
     max_pp, max_tp, max_zo = 0, 0, 0
     for fn in fns:
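The rewritten assert only reorders the same predicate (Python's `and` already binds tighter than `or`) and drops an `f` prefix from a placeholder-free message, so the behavior is unchanged: an FSDP run may only load checkpoints whose names carry a "_zo" shard segment, and a non-FSDP run may only load ones without it. Here is a small standalone rendering of that guard with hypothetical filenames; the exact naming scheme is only implied by the `_tp`/`_pp`/`_zo` handling in this function.

```python
def matches_fsdp_mode(fns, use_fsdp):
    # Same predicate as the assert in load_model_checkpoint: pick any model_* file
    # (ignoring .md5 sidecars) and require that the presence of a "_zo" segment
    # agrees with whether the current run uses FSDP.
    test_fn = [f for f in fns if f.startswith("model_t") and not f.endswith(".md5")].pop()
    return ("_zo" in test_fn and use_fsdp) or ("_zo" not in test_fn and not use_fsdp)


# Hypothetical file names, for illustration only.
assert matches_fsdp_mode(["model_tp0_pp0_zo0.pt"], use_fsdp=True)
assert not matches_fsdp_mode(["model_tp0_pp0.pt"], use_fsdp=True)
assert matches_fsdp_mode(["model_tp0_pp0.pt"], use_fsdp=False)
```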