import copy
import os
from typing import Callable , Optional , Union
import torch
from torch . nn import Module
from colossalai . interface import pretrained as pretrained_interface
class PretrainedManager:
    """Swap ``PreTrainedModel.from_pretrained`` for ``new_from_pretrained`` and back.

    Both operations are silent no-ops when ``transformers`` cannot be imported.
    They are also idempotent: calling ``inject`` twice keeps the first saved
    original, and calling ``recover`` without a prior ``inject`` leaves
    ``PreTrainedModel`` untouched.
    """

    # Plain function behind the original ``PreTrainedModel.from_pretrained``;
    # ``None`` whenever the replacement is not installed.
    old_from_pretrained: Optional[Callable] = None

    @staticmethod
    def inject() -> None:
        """Install ``new_from_pretrained`` as ``PreTrainedModel.from_pretrained``."""
        try:
            from transformers.modeling_utils import PreTrainedModel
        except ImportError:
            return
        if PretrainedManager.old_from_pretrained is not None:
            # Already injected: saving again would store the replacement itself
            # and make the real original unrecoverable.
            return
        # recover bound method to plain function
        PretrainedManager.old_from_pretrained = PreTrainedModel.from_pretrained.__func__
        PreTrainedModel.from_pretrained = new_from_pretrained

    @staticmethod
    def recover() -> None:
        """Restore the ``from_pretrained`` that was saved by ``inject``."""
        try:
            from transformers.modeling_utils import PreTrainedModel
        except ImportError:
            return
        original = PretrainedManager.old_from_pretrained
        if original is None:
            # Nothing was injected; wrapping ``None`` in a classmethod would
            # break ``PreTrainedModel.from_pretrained``.
            return
        # convert plain function to class method
        PreTrainedModel.from_pretrained = classmethod(original)
        PretrainedManager.old_from_pretrained = None
@classmethod
def new_from_pretrained(
    cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs
) -> Module:
    """Instantiate ``cls`` from its config and record where its checkpoint lives.

    This function resolves the checkpoint file (local directory, local file,
    remote URL, or hub cache) but never reads weights from it. The resolved
    path is attached to the model through
    ``pretrained_interface.set_pretrained_path``.

    Args:
        pretrained_model_name_or_path: Directory, file, URL or hub identifier of
            the checkpoint. ``None`` skips checkpoint resolution entirely.
        *model_args: Extra positional arguments forwarded to ``cls(config, ...)``.
        **kwargs: Loading options. ``config``, ``cache_dir``, ``force_download``,
            ``proxies``, ``local_files_only``, ``use_auth_token``, ``revision``,
            ``subfolder``, ``variant``, ``torch_dtype``, ``use_safetensors``,
            ``_fast_init``, ``_from_pipeline``, ``_from_auto`` and
            ``_commit_hash`` are honoured. Several other known loading options
            (``state_dict``, ``device_map``, ``load_in_8bit``, ...) are popped
            and discarded. Whatever is left triggers a warning and is forwarded
            to the config / model / generation-config constructors.

    Returns:
        The model in eval mode with tied weights.

    Raises:
        EnvironmentError: If no checkpoint file can be located.
        ValueError: If ``torch_dtype`` is neither ``None`` nor a ``torch.dtype``.
    """
    from transformers import GenerationConfig
    from transformers.configuration_utils import PretrainedConfig
    from transformers.modeling_utils import (
        ContextManagers,
        _add_variant,
        cached_file,
        download_url,
        has_file,
        is_offline_mode,
        is_remote_url,
        no_init_weights,
    )
    from transformers.utils import (
        SAFE_WEIGHTS_INDEX_NAME,
        SAFE_WEIGHTS_NAME,
        WEIGHTS_INDEX_NAME,
        WEIGHTS_NAME,
        is_safetensors_available,
        logging,
    )

    logger = logging.get_logger(__name__)

    # Options that are honoured.
    config = kwargs.pop("config", None)
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", False)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    _ = kwargs.pop("mirror", None)
    from_pipeline = kwargs.pop("_from_pipeline", None)
    from_auto_class = kwargs.pop("_from_auto", False)
    _fast_init = kwargs.pop("_fast_init", True)
    torch_dtype = kwargs.pop("torch_dtype", None)
    subfolder = kwargs.pop("subfolder", "")
    commit_hash = kwargs.pop("_commit_hash", None)
    variant = kwargs.pop("variant", None)

    # Options that are accepted but deliberately dropped so they do not leak
    # into the config / model constructors.
    kwargs.pop("state_dict", None)
    kwargs.pop("from_tf", False)
    kwargs.pop("from_flax", False)
    kwargs.pop("output_loading_info", False)
    kwargs.pop("trust_remote_code", None)
    kwargs.pop("low_cpu_mem_usage", None)
    kwargs.pop("device_map", None)
    kwargs.pop("max_memory", None)
    kwargs.pop("offload_folder", None)
    kwargs.pop("offload_state_dict", False)
    kwargs.pop("load_in_8bit", False)
    kwargs.pop("load_in_4bit", False)
    kwargs.pop("quantization_config", None)
    kwargs.pop("adapter_kwargs", {})
    kwargs.pop("adapter_name", "default")
    kwargs.pop("use_flash_attention_2", False)

    use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)

    if len(kwargs) > 0:
        logger.warning(f"Below kwargs may be ignored: {list(kwargs.keys())}")

    user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
    if from_pipeline is not None:
        user_agent["using_pipeline"] = from_pipeline

    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Load config if we don't provide a configuration
    if not isinstance(config, PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        config, model_kwargs = cls.config_class.from_pretrained(
            config_path,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            subfolder=subfolder,
            _from_auto=from_auto_class,
            _from_pipeline=from_pipeline,
            **kwargs,
        )
    else:
        # Work on a copy so the caller's config object is never mutated.
        config = copy.deepcopy(config)
        kwarg_attn_imp = kwargs.pop("attn_implementation", None)
        if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
            config._attn_implementation = kwarg_attn_imp
        model_kwargs = kwargs

    if commit_hash is None:
        commit_hash = getattr(config, "_commit_hash", None)

    if pretrained_model_name_or_path is not None:
        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        is_local = os.path.isdir(pretrained_model_name_or_path)
        if is_local:
            # Precedence: safetensors, sharded safetensors, PyTorch, sharded
            # PyTorch. Safetensors candidates are skipped only when the caller
            # explicitly disabled them.
            candidate_names = []
            if use_safetensors is not False:
                candidate_names.extend([SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME])
            candidate_names.extend([WEIGHTS_NAME, WEIGHTS_INDEX_NAME])
            for weights_name in candidate_names:
                candidate_path = os.path.join(
                    pretrained_model_name_or_path, subfolder, _add_variant(weights_name, variant)
                )
                if os.path.isfile(candidate_path):
                    archive_file = candidate_path
                    break
            else:
                raise EnvironmentError(
                    f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
                    f" {pretrained_model_name_or_path}."
                )
        elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
            archive_file = pretrained_model_name_or_path
            is_local = True
        elif is_remote_url(pretrained_model_name_or_path):
            filename = pretrained_model_name_or_path
            resolved_archive_file = download_url(pretrained_model_name_or_path)
        else:
            # set correct filename
            if use_safetensors is not False:
                filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
            else:
                filename = _add_variant(WEIGHTS_NAME, variant)

            try:
                # Load from URL or cache if already cached
                cached_file_kwargs = {
                    "cache_dir": cache_dir,
                    "force_download": force_download,
                    "proxies": proxies,
                    "local_files_only": local_files_only,
                    "use_auth_token": use_auth_token,
                    "user_agent": user_agent,
                    "revision": revision,
                    "subfolder": subfolder,
                    "_raise_exceptions_for_missing_entries": False,
                    "_commit_hash": commit_hash,
                }
                resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

                # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
                # result when internet is up, the repo and revision exist, but the file does not.
                if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                    resolved_archive_file = cached_file(
                        pretrained_model_name_or_path,
                        _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
                        **cached_file_kwargs,
                    )
                    if resolved_archive_file is None:
                        if use_safetensors:
                            raise EnvironmentError(
                                f"{pretrained_model_name_or_path} does not appear to have a file named"
                                f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or"
                                f" {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded"
                                " with `safetensors`. Please make sure that the model has been saved with"
                                " `safe_serialization=True` or do not set `use_safetensors=True`."
                            )
                        # This repo has no safetensors file of any kind, we switch to PyTorch.
                        filename = _add_variant(WEIGHTS_NAME, variant)
                        resolved_archive_file = cached_file(
                            pretrained_model_name_or_path, filename, **cached_file_kwargs
                        )
                if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                    resolved_archive_file = cached_file(
                        pretrained_model_name_or_path,
                        _add_variant(WEIGHTS_INDEX_NAME, variant),
                        **cached_file_kwargs,
                    )
                if resolved_archive_file is None:
                    # Nothing usable was found; check for non-variant weights to give a helpful error
                    # message.
                    has_file_kwargs = {
                        "revision": revision,
                        "proxies": proxies,
                        "use_auth_token": use_auth_token,
                    }
                    if variant is not None and has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
                        raise EnvironmentError(
                            f"{pretrained_model_name_or_path} does not appear to have a file named"
                            f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
                            f" {variant}. Use `variant=None` to load this model from those weights."
                        )
                    else:
                        raise EnvironmentError(
                            f"{pretrained_model_name_or_path} does not appear to have a file named"
                            f" {_add_variant(WEIGHTS_NAME, variant)}"
                        )
            except EnvironmentError:
                # Raise any environment error raised by `cached_file`. It will have a helpful error message
                # adapted to the original exception.
                raise
            except Exception as err:
                # For any other exception, we throw a generic error (keeping the cause attached).
                raise EnvironmentError(
                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
                    " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
                    f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
                    f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}."
                ) from err

        if is_local:
            logger.info(f"loading weights file {archive_file}")
            resolved_archive_file = archive_file
        else:
            logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
    else:
        resolved_archive_file = None

    # Set the dtype to instantiate the model under: if torch_dtype is given, it
    # temporarily becomes the global default dtype.
    dtype_orig = None
    if torch_dtype is not None:
        if not isinstance(torch_dtype, torch.dtype):
            raise ValueError(f"`torch_dtype` can be either `torch.dtype` or `None`, but received {torch_dtype}")
        dtype_orig = cls._set_default_torch_dtype(torch_dtype)

    try:
        config.name_or_path = pretrained_model_name_or_path

        # Instantiate model.
        init_contexts = [no_init_weights(_enable=_fast_init)]
        with ContextManagers(init_contexts):
            model = cls(config, *model_args, **model_kwargs)
    finally:
        # Restore the default dtype even if construction fails, so the
        # process-wide default is never left modified.
        if dtype_orig is not None:
            torch.set_default_dtype(dtype_orig)

    # make sure token embedding weights are still tied if needed
    model.tie_weights()

    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()

    # If it is a model with generation capabilities, attempt to load the generation config
    if model.can_generate():
        try:
            model.generation_config = GenerationConfig.from_pretrained(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                subfolder=subfolder,
                _from_auto=from_auto_class,
                _from_pipeline=from_pipeline,
                **kwargs,
            )
        except (OSError, TypeError):
            logger.info("Generation config file not found, using a generation config created from the model config.")

    # set pretrained path
    if resolved_archive_file:
        pretrained_interface.set_pretrained_path(model, resolved_archive_file)

    return model