fix load ckpt bug2

pull/293/head
zaglc 2023-09-25 16:11:50 +08:00
parent 5b62a3957a
commit 6b7ca1c6b3
1 changed file with 3 additions and 4 deletions


@@ -11,7 +11,6 @@ from enum import Enum
from typing import Callable, Dict, Union
import torch
from torch.distributed.fsdp import FullStateDictConfig, LocalStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
@@ -177,6 +176,7 @@ def get_shard_state_dict(shard_model):
return shard_states
def load_shard_state_dict(shard_model, shard_state, **kwargs):
"""
Only used for FSDP module loading.
@@ -329,10 +329,9 @@ def load_model_checkpoint(folder, model):
# avoid ckpt misuse between FSDP and no-FSDP
test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop()
- assert (
- "_zo" in test_fn and gpc.config.parallel.use_fsdp or
+ assert ("_zo" in test_fn and gpc.config.parallel.use_fsdp) or (
"_zo" not in test_fn and not gpc.config.parallel.use_fsdp
- ), f"FSDP model wants to load no-FSDP ckpts or reverse"
+ ), "FSDP model wants to load no-FSDP ckpts or reverse"
max_pp, max_tp, max_zo = 0, 0, 0
for fn in fns:
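
For reference, the guard that this commit rewrites can be read in isolation as the short sketch below. This is a standalone approximation rather than the repository's exact code: use_fsdp stands in for gpc.config.parallel.use_fsdp, and fns is the list of filenames found in the checkpoint folder. The convention relied on is that FSDP-sharded model checkpoints carry "_zo" in their filenames, so the assert refuses to load an FSDP checkpoint into a non-FSDP run and vice versa.

    def check_ckpt_matches_fsdp(fns, use_fsdp):
        # Keep only model checkpoint files, ignoring .md5 sidecar files.
        model_fns = [f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]
        test_fn = model_fns.pop()
        # "_zo" in the filename marks an FSDP-sharded checkpoint; the flag must agree.
        assert ("_zo" in test_fn and use_fsdp) or (
            "_zo" not in test_fn and not use_fsdp
        ), "FSDP model wants to load no-FSDP ckpts or reverse"

Note that the rewritten assert is logically equivalent to the old one, since Python's "and" binds more tightly than "or"; the change makes the grouping explicit and drops the f prefix from a format string that had no placeholders.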