# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
# ColossalAI/colossalai/lazy/pretrained.py
#
# 329 lines
# 14 KiB

import copy
import os
from typing import Callable, Optional, Union
import torch
from torch.nn import Module
from colossalai.interface import pretrained as pretrained_interface
class PretrainedManager:
    """Install / remove this module's ``new_from_pretrained`` on ``transformers.PreTrainedModel``.

    ``inject`` stashes the plain function behind ``PreTrainedModel.from_pretrained`` and replaces
    the attribute with ``new_from_pretrained``; ``recover`` re-wraps the stashed function as a
    classmethod and clears the stash. Both methods are no-ops when ``transformers`` cannot be
    imported, and both are idempotent: a second ``inject`` or an unmatched ``recover`` does nothing.
    """

    # Plain function behind the original ``PreTrainedModel.from_pretrained`` while the patch is
    # installed; ``None`` when nothing is injected.
    old_from_pretrained: Optional[Callable] = None

    @staticmethod
    def inject() -> None:
        """Replace ``PreTrainedModel.from_pretrained`` with ``new_from_pretrained``.

        Does nothing if ``transformers`` is unavailable or the patch is already installed.
        """
        if PretrainedManager.old_from_pretrained is not None:
            # Already injected: stashing again would capture new_from_pretrained itself and
            # recover() would then be unable to restore the original implementation.
            return
        try:
            from transformers.modeling_utils import PreTrainedModel
        except ImportError:
            return
        # recover bound method to plain function
        PretrainedManager.old_from_pretrained = PreTrainedModel.from_pretrained.__func__
        PreTrainedModel.from_pretrained = new_from_pretrained

    @staticmethod
    def recover() -> None:
        """Restore the ``PreTrainedModel.from_pretrained`` stashed by ``inject``.

        Does nothing if ``transformers`` is unavailable or nothing was injected.
        """
        if PretrainedManager.old_from_pretrained is None:
            # Nothing stashed: wrapping None in classmethod would break from_pretrained.
            return
        try:
            from transformers.modeling_utils import PreTrainedModel
        except ImportError:
            return
        # convert plain function to class method
        PreTrainedModel.from_pretrained = classmethod(PretrainedManager.old_from_pretrained)
        PretrainedManager.old_from_pretrained = None
@classmethod
def new_from_pretrained(
    cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs
) -> Module:
    """Replacement for ``PreTrainedModel.from_pretrained`` that builds the model without loading weights.

    It resolves the config and the checkpoint file (or shard index) the way ``transformers`` does.
    It then instantiates ``cls`` under ``no_init_weights``, puts the model in eval mode and
    attaches a generation config when the model can generate. It never reads a state dict.
    Instead, the resolved checkpoint path is recorded on the model through
    ``colossalai.interface.pretrained.set_pretrained_path``. Presumably the weights are loaded
    from that path later by the lazy-init machinery (TODO confirm against callers).

    Args:
        pretrained_model_name_or_path: A local directory, a local file, a remote URL, or a hub
            model id. ``None`` skips checkpoint resolution, so no path is recorded.
        *model_args: Forwarded positionally to the model constructor.
        **kwargs: Supports a subset of ``from_pretrained`` options (``config``, ``cache_dir``,
            ``revision``, ``subfolder``, ``variant``, ``torch_dtype``, ``use_safetensors``, ...).
            Weight-loading/placement options (``state_dict``, ``device_map``, quantization, ...)
            are popped and discarded. Remaining keys are logged as possibly ignored. They are
            forwarded to config loading (whose unused kwargs reach the model constructor) and to
            ``GenerationConfig.from_pretrained``.

    Returns:
        The constructed model, in eval mode.

    Raises:
        EnvironmentError: If no suitable checkpoint file can be located.
        ValueError: If ``torch_dtype`` is given but is not a ``torch.dtype``.
    """
    from transformers import GenerationConfig
    from transformers.configuration_utils import PretrainedConfig
    from transformers.modeling_utils import (
        ContextManagers,
        _add_variant,
        cached_file,
        download_url,
        has_file,
        is_offline_mode,
        is_remote_url,
        no_init_weights,
    )
    from transformers.utils import (
        SAFE_WEIGHTS_INDEX_NAME,
        SAFE_WEIGHTS_NAME,
        WEIGHTS_INDEX_NAME,
        WEIGHTS_NAME,
        is_safetensors_available,
        logging,
    )

    logger = logging.get_logger(__name__)

    config = kwargs.pop("config", None)
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", False)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    kwargs.pop("mirror", None)
    from_pipeline = kwargs.pop("_from_pipeline", None)
    from_auto_class = kwargs.pop("_from_auto", False)
    _fast_init = kwargs.pop("_fast_init", True)
    torch_dtype = kwargs.pop("torch_dtype", None)
    subfolder = kwargs.pop("subfolder", "")
    commit_hash = kwargs.pop("_commit_hash", None)
    variant = kwargs.pop("variant", None)
    use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)

    # Options that only matter when weights are actually loaded/placed. They are popped so they do
    # not reach the config or model constructor; nothing in this function acts on them.
    weight_loading_keys = (
        "state_dict",
        "from_tf",
        "from_flax",
        "output_loading_info",
        "trust_remote_code",
        "low_cpu_mem_usage",
        "device_map",
        "max_memory",
        "offload_folder",
        "offload_state_dict",
        "load_in_8bit",
        "load_in_4bit",
        "quantization_config",
        "adapter_kwargs",
        "adapter_name",
        "use_flash_attention_2",
    )
    for key in weight_loading_keys:
        kwargs.pop(key, None)

    if kwargs:
        logger.warning(f"Below kwargs may be ignored: {list(kwargs.keys())}")

    user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
    if from_pipeline is not None:
        user_agent["using_pipeline"] = from_pipeline

    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Load config if we don't provide a configuration
    if not isinstance(config, PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        config, model_kwargs = cls.config_class.from_pretrained(
            config_path,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            subfolder=subfolder,
            _from_auto=from_auto_class,
            _from_pipeline=from_pipeline,
            **kwargs,
        )
    else:
        # Copy so the attribute assignments below never mutate the caller's config object.
        config = copy.deepcopy(config)
        # An explicit `attn_implementation` kwarg overrides the one stored in a caller-supplied config.
        kwarg_attn_imp = kwargs.pop("attn_implementation", None)
        if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
            config._attn_implementation = kwarg_attn_imp
        model_kwargs = kwargs

    if commit_hash is None:
        commit_hash = getattr(config, "_commit_hash", None)

    # Resolve the checkpoint file (a single weights file or a shard index) the model would load from.
    if pretrained_model_name_or_path is not None:
        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        is_local = os.path.isdir(pretrained_model_name_or_path)
        if is_local:
            # Probe in priority order: safetensors (unless disabled) before PyTorch, and within each
            # format the single-file checkpoint before the shard index.
            candidate_names = [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]
            if use_safetensors is not False:
                candidate_names = [SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME] + candidate_names
            for name in candidate_names:
                candidate = os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(name, variant))
                if os.path.isfile(candidate):
                    archive_file = candidate
                    break
            else:
                raise EnvironmentError(
                    f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
                    f" {pretrained_model_name_or_path}."
                )
        elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
            archive_file = pretrained_model_name_or_path
            is_local = True
        elif is_remote_url(pretrained_model_name_or_path):
            filename = pretrained_model_name_or_path
            resolved_archive_file = download_url(pretrained_model_name_or_path)
        else:
            # set correct filename
            if use_safetensors is not False:
                filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
            else:
                filename = _add_variant(WEIGHTS_NAME, variant)

            try:
                # Load from URL or cache if already cached
                cached_file_kwargs = {
                    "cache_dir": cache_dir,
                    "force_download": force_download,
                    "proxies": proxies,
                    "local_files_only": local_files_only,
                    "use_auth_token": use_auth_token,
                    "user_agent": user_agent,
                    "revision": revision,
                    "subfolder": subfolder,
                    "_raise_exceptions_for_missing_entries": False,
                    "_commit_hash": commit_hash,
                }
                resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

                # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
                # result when internet is up, the repo and revision exist, but the file does not.
                if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                    resolved_archive_file = cached_file(
                        pretrained_model_name_or_path,
                        _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
                        **cached_file_kwargs,
                    )
                    if resolved_archive_file is None:
                        if use_safetensors:
                            raise EnvironmentError(
                                f"{pretrained_model_name_or_path} does not appear to have a file named"
                                f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or"
                                f" {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded with"
                                " `safetensors`. Please make sure that the model has been saved with"
                                " `safe_serialization=True` or do not set `use_safetensors=True`."
                            )
                        # This repo has no safetensors file of any kind, we switch to PyTorch.
                        filename = _add_variant(WEIGHTS_NAME, variant)
                        resolved_archive_file = cached_file(
                            pretrained_model_name_or_path, filename, **cached_file_kwargs
                        )

                if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                    resolved_archive_file = cached_file(
                        pretrained_model_name_or_path,
                        _add_variant(WEIGHTS_INDEX_NAME, variant),
                        **cached_file_kwargs,
                    )

                if resolved_archive_file is None:
                    # Otherwise, check whether a non-variant file exists to give a more helpful error message.
                    has_file_kwargs = {
                        "revision": revision,
                        "proxies": proxies,
                        "use_auth_token": use_auth_token,
                    }
                    if variant is not None and has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
                        raise EnvironmentError(
                            f"{pretrained_model_name_or_path} does not appear to have a file named"
                            f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
                            f" {variant}. Use `variant=None` to load this model from those weights."
                        )
                    raise EnvironmentError(
                        f"{pretrained_model_name_or_path} does not appear to have a file named"
                        f" {_add_variant(WEIGHTS_NAME, variant)}"
                    )
            except EnvironmentError:
                # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
                # to the original exception.
                raise
            except Exception as e:
                # For any other exception, we throw a generic error (keeping the cause attached).
                raise EnvironmentError(
                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
                    " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
                    f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
                    f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}."
                ) from e

        if is_local:
            logger.info(f"loading weights file {archive_file}")
            resolved_archive_file = archive_file
        else:
            logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
    else:
        resolved_archive_file = None

    # If torch_dtype is given, build the model under that global default dtype.
    dtype_orig = None
    if torch_dtype is not None:
        if not isinstance(torch_dtype, torch.dtype):
            raise ValueError(f"`torch_dtype` can be either `torch.dtype` or `None`, but received {torch_dtype}")
        # Its return value is what gets restored via torch.set_default_dtype below.
        dtype_orig = cls._set_default_torch_dtype(torch_dtype)

    try:
        config.name_or_path = pretrained_model_name_or_path
        # Instantiate model.
        init_contexts = [no_init_weights(_enable=_fast_init)]
        with ContextManagers(init_contexts):
            model = cls(config, *model_args, **model_kwargs)
    finally:
        # The default dtype is process-global: restore it even if construction raised.
        if dtype_orig is not None:
            torch.set_default_dtype(dtype_orig)

    # make sure token embedding weights are still tied if needed
    model.tie_weights()

    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()

    # If it is a model with generation capabilities, attempt to load the generation config
    if model.can_generate():
        try:
            model.generation_config = GenerationConfig.from_pretrained(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                subfolder=subfolder,
                _from_auto=from_auto_class,
                _from_pipeline=from_pipeline,
                **kwargs,
            )
        except (OSError, TypeError):
            logger.info("Generation config file not found, using a generation config created from the model config.")

    # set pretrained path
    if resolved_archive_file:
        pretrained_interface.set_pretrained_path(model, resolved_archive_file)

    return model