fix load ckpt bug2

pull/293/head
zaglc 2023-09-25 16:11:50 +08:00
parent 5b62a3957a
commit 6b7ca1c6b3
1 changed file with 3 additions and 4 deletions

@@ -11,7 +11,6 @@ from enum import Enum
 from typing import Callable, Dict, Union

 import torch
-from torch.distributed.fsdp import FullStateDictConfig, LocalStateDictConfig
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import StateDictType
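
Note: the deleted import pulled in FSDP's state-dict config classes, which this file apparently no longer references. For context, a minimal sketch of how those classes are normally used with PyTorch FSDP (the model variable here is an illustrative FSDP-wrapped module, not this repo's code):

import torch
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

# Illustrative: gather a full (unsharded) state dict, offloaded
# to CPU and materialized on rank 0 only.
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
    full_state = model.state_dict()
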
@@ -177,6 +176,7 @@ def get_shard_state_dict(shard_model):
     return shard_states

+
 def load_shard_state_dict(shard_model, shard_state, **kwargs):
     """
     Only used for FSDP module loading.
@@ -329,10 +329,9 @@ def load_model_checkpoint(folder, model):
     # avoid ckpt misuse between FSDP and no-FSDP
     test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop()
-    assert (
-        "_zo" in test_fn and gpc.config.parallel.use_fsdp or
+    assert ("_zo" in test_fn and gpc.config.parallel.use_fsdp) or (
         "_zo" not in test_fn and not gpc.config.parallel.use_fsdp
-    ), f"FSDP model wants to load no-FSDP ckpts or reverse"
+    ), "FSDP model wants to load no-FSDP ckpts or reverse"
     max_pp, max_tp, max_zo = 0, 0, 0
     for fn in fns:
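
Review note: in Python, "and" binds tighter than "or", so the added parentheses do not change the assertion's result; they make the intended grouping explicit. The f-string prefix on the message is also dropped since the string has no placeholders. A standalone sketch of the guard (the filenames and use_fsdp flag below are made up for illustration, not taken from the repo):

# Sketch of the checkpoint-kind guard: an FSDP run must load FSDP
# ("_zo"-tagged) checkpoints, and a non-FSDP run must not.
def check_ckpt_kind(fns, use_fsdp):
    test_fn = [f for f in fns if f.startswith("model_t") and not f.endswith(".md5")].pop()
    assert ("_zo" in test_fn and use_fsdp) or ("_zo" not in test_fn and not use_fsdp), \
        "FSDP model wants to load no-FSDP ckpts or reverse"

check_ckpt_kind(["model_tp0_pp0_zo0.pt"], use_fsdp=True)   # FSDP ckpt, FSDP run: passes
check_ckpt_kind(["model_tp0_pp0.pt"], use_fsdp=False)      # plain ckpt, plain run: passes
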