mirror of https://github.com/hpcaitech/ColossalAI
[colossalai/checkpoint_io/...] fix bug in load_state_dict_into_model; format error msg (#6020)
* fix bug in load_state_dict_into_model; format error msg
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Update utils.py to support checking missing_keys
* Update general_checkpoint_io.py: fix bug in missing_keys error message
* retrigger tests
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/6046/head
parent
e96a0761ea
commit
e9032fb0b2
|
@ -220,9 +220,9 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||||
if strict:
|
if strict:
|
||||||
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
|
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
|
||||||
if len(remain_keys) > 0:
|
if len(remain_keys) > 0:
|
||||||
error_msgs = "Missing key(s) in state_dict: {}. ".format(
|
error_msgs = [
|
||||||
", ".join('"{}"'.format(k) for k in missing_keys)
|
"Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in remain_keys))
|
||||||
)
|
]
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
||||||
self.__class__.__name__, "\n\t".join(error_msgs)
|
self.__class__.__name__, "\n\t".join(error_msgs)
|
||||||
|
|
|
@ -381,9 +381,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||||
remain_keys = remain_keys.union(set(missing_file_keys))
|
remain_keys = remain_keys.union(set(missing_file_keys))
|
||||||
if len(remain_keys) > 0:
|
if len(remain_keys) > 0:
|
||||||
if strict:
|
if strict:
|
||||||
error_msgs = "Missing key(s) in state_dict: {}. ".format(
|
error_msgs = [
|
||||||
", ".join('"{}"'.format(k) for k in missing_keys)
|
"Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))
|
||||||
)
|
]
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
||||||
self.__class__.__name__, "\n\t".join(error_msgs)
|
self.__class__.__name__, "\n\t".join(error_msgs)
|
||||||
|
|
|
@ -553,10 +553,10 @@ def load_state_dict_into_model(
|
||||||
|
|
||||||
def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
|
def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
|
||||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||||
args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
|
args = (state_dict, prefix, local_metadata, True, sub_missing_keys, unexpected_keys, error_msgs)
|
||||||
# Parameters of module and children will start with prefix. We can exit early if there are none in this
|
# Parameters of module and children will start with prefix. We can exit early if there are none in this
|
||||||
# state_dict
|
# state_dict
|
||||||
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
|
if strict or len([key for key in state_dict if key.startswith(prefix)]) > 0:
|
||||||
module._load_from_state_dict(*args)
|
module._load_from_state_dict(*args)
|
||||||
if load_sub_module:
|
if load_sub_module:
|
||||||
for name, child in module._modules.items():
|
for name, child in module._modules.items():
|
||||||
|
@ -570,9 +570,9 @@ def load_state_dict_into_model(
|
||||||
|
|
||||||
if strict:
|
if strict:
|
||||||
if len(unexpected_keys) > 0:
|
if len(unexpected_keys) > 0:
|
||||||
error_msgs = "Unexpected key(s) in state_dict: {}. ".format(
|
error_msgs = [
|
||||||
", ".join('"{}"'.format(k) for k in unexpected_keys)
|
"Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys))
|
||||||
)
|
]
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
|
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
|
||||||
)
|
)
|
||||||
|
|
|
@ -116,9 +116,9 @@ class InferCheckpoint_io(GeneralCheckpointIO):
|
||||||
remain_keys = remain_keys.union(set(missing_file_keys))
|
remain_keys = remain_keys.union(set(missing_file_keys))
|
||||||
if len(remain_keys) > 0:
|
if len(remain_keys) > 0:
|
||||||
if strict:
|
if strict:
|
||||||
error_msgs = "Missing key(s) in state_dict: {}. ".format(
|
error_msgs = [
|
||||||
", ".join('"{}"'.format(k) for k in missing_keys)
|
"Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))
|
||||||
)
|
]
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
||||||
self.__class__.__name__, "\n\t".join(error_msgs)
|
self.__class__.__name__, "\n\t".join(error_msgs)
|
||||||
|
|
Loading…
Reference in New Issue