[colossalai/checkpoint_io/...] fix bug in load_state_dict_into_model; format error msg (#6020)

* fix bug in load_state_dict_into_model; format error msg

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

to support checking missing_keys

* Update general_checkpoint_io.py

fix bug in missing_keys error message

* retrigger tests

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/6046/head
Gao, Ruiyuan 2024-09-02 16:56:35 +08:00 committed by GitHub
parent e96a0761ea
commit e9032fb0b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 14 additions and 14 deletions

View File

@ -220,9 +220,9 @@ class GeneralCheckpointIO(CheckpointIO):
if strict: if strict:
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys)) remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
if len(remain_keys) > 0: if len(remain_keys) > 0:
error_msgs = "Missing key(s) in state_dict: {}. ".format( error_msgs = [
", ".join('"{}"'.format(k) for k in missing_keys) "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in remain_keys))
) ]
raise RuntimeError( raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format( "Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs) self.__class__.__name__, "\n\t".join(error_msgs)

View File

@ -381,9 +381,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
remain_keys = remain_keys.union(set(missing_file_keys)) remain_keys = remain_keys.union(set(missing_file_keys))
if len(remain_keys) > 0: if len(remain_keys) > 0:
if strict: if strict:
error_msgs = "Missing key(s) in state_dict: {}. ".format( error_msgs = [
", ".join('"{}"'.format(k) for k in missing_keys) "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))
) ]
raise RuntimeError( raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format( "Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs) self.__class__.__name__, "\n\t".join(error_msgs)

View File

@ -553,10 +553,10 @@ def load_state_dict_into_model(
def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True): def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs) args = (state_dict, prefix, local_metadata, True, sub_missing_keys, unexpected_keys, error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this # Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict # state_dict
if len([key for key in state_dict if key.startswith(prefix)]) > 0: if strict or len([key for key in state_dict if key.startswith(prefix)]) > 0:
module._load_from_state_dict(*args) module._load_from_state_dict(*args)
if load_sub_module: if load_sub_module:
for name, child in module._modules.items(): for name, child in module._modules.items():
@ -570,9 +570,9 @@ def load_state_dict_into_model(
if strict: if strict:
if len(unexpected_keys) > 0: if len(unexpected_keys) > 0:
error_msgs = "Unexpected key(s) in state_dict: {}. ".format( error_msgs = [
", ".join('"{}"'.format(k) for k in unexpected_keys) "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys))
) ]
raise RuntimeError( raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)) "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
) )

View File

@ -116,9 +116,9 @@ class InferCheckpoint_io(GeneralCheckpointIO):
remain_keys = remain_keys.union(set(missing_file_keys)) remain_keys = remain_keys.union(set(missing_file_keys))
if len(remain_keys) > 0: if len(remain_keys) > 0:
if strict: if strict:
error_msgs = "Missing key(s) in state_dict: {}. ".format( error_msgs = [
", ".join('"{}"'.format(k) for k in missing_keys) "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))
) ]
raise RuntimeError( raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format( "Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs) self.__class__.__name__, "\n\t".join(error_msgs)