diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index b9253a56d..2534fa163 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -220,9 +220,9 @@ class GeneralCheckpointIO(CheckpointIO): if strict: remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys)) if len(remain_keys) > 0: - error_msgs = "Missing key(s) in state_dict: {}. ".format( - ", ".join('"{}"'.format(k) for k in missing_keys) - ) + error_msgs = [ + "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in remain_keys)) + ] raise RuntimeError( "Error(s) in loading state_dict for {}:\n\t{}".format( self.__class__.__name__, "\n\t".join(error_msgs) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 043e5c2b0..3b6917d32 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -381,9 +381,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): remain_keys = remain_keys.union(set(missing_file_keys)) if len(remain_keys) > 0: if strict: - error_msgs = "Missing key(s) in state_dict: {}. ".format( - ", ".join('"{}"'.format(k) for k in missing_keys) - ) + error_msgs = [ + "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)) + ] raise RuntimeError( "Error(s) in loading state_dict for {}:\n\t{}".format( self.__class__.__name__, "\n\t".join(error_msgs) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 36138f33e..b3917bd9d 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -553,10 +553,10 @@ def load_state_dict_into_model( def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs) + args = (state_dict, prefix, local_metadata, True, sub_missing_keys, unexpected_keys, error_msgs) # Parameters of module and children will start with prefix. We can exit early if there are none in this # state_dict - if len([key for key in state_dict if key.startswith(prefix)]) > 0: + if strict or len([key for key in state_dict if key.startswith(prefix)]) > 0: module._load_from_state_dict(*args) if load_sub_module: for name, child in module._modules.items(): @@ -570,9 +570,9 @@ def load_state_dict_into_model( if strict: if len(unexpected_keys) > 0: - error_msgs = "Unexpected key(s) in state_dict: {}. ".format( - ", ".join('"{}"'.format(k) for k in unexpected_keys) - ) + error_msgs = [ + "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys)) + ] raise RuntimeError( "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)) ) diff --git a/colossalai/inference/core/plugin.py b/colossalai/inference/core/plugin.py index d6a2b8b16..ae526b888 100644 --- a/colossalai/inference/core/plugin.py +++ b/colossalai/inference/core/plugin.py @@ -116,9 +116,9 @@ class InferCheckpoint_io(GeneralCheckpointIO): remain_keys = remain_keys.union(set(missing_file_keys)) if len(remain_keys) > 0: if strict: - error_msgs = "Missing key(s) in state_dict: {}. ".format( - ", ".join('"{}"'.format(k) for k in missing_keys) - ) + error_msgs = [ + "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)) + ] raise RuntimeError( "Error(s) in loading state_dict for {}:\n\t{}".format( self.__class__.__name__, "\n\t".join(error_msgs)