diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index e1800f29b..47f0ba0af 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -558,6 +558,10 @@ def load_state_dict_into_model( missing_keys = missing_keys.append(sub_missing_keys) if strict: + if len(error_msgs) > 0: + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)) + ) if len(unexpected_keys) > 0: error_msgs = "Unexpected key(s) in state_dict: {}. ".format( ", ".join('"{}"'.format(k) for k in unexpected_keys)