|
|
|
@ -1,3 +1,4 @@
|
|
|
|
|
import copy |
|
|
|
|
import os |
|
|
|
|
from typing import Callable, Optional, Union |
|
|
|
|
|
|
|
|
@ -74,6 +75,24 @@ def new_from_pretrained(
|
|
|
|
|
subfolder = kwargs.pop("subfolder", "") |
|
|
|
|
commit_hash = kwargs.pop("_commit_hash", None) |
|
|
|
|
variant = kwargs.pop("variant", None) |
|
|
|
|
|
|
|
|
|
kwargs.pop("state_dict", None) |
|
|
|
|
kwargs.pop("from_tf", False) |
|
|
|
|
kwargs.pop("from_flax", False) |
|
|
|
|
kwargs.pop("output_loading_info", False) |
|
|
|
|
kwargs.pop("trust_remote_code", None) |
|
|
|
|
kwargs.pop("low_cpu_mem_usage", None) |
|
|
|
|
kwargs.pop("device_map", None) |
|
|
|
|
kwargs.pop("max_memory", None) |
|
|
|
|
kwargs.pop("offload_folder", None) |
|
|
|
|
kwargs.pop("offload_state_dict", False) |
|
|
|
|
kwargs.pop("load_in_8bit", False) |
|
|
|
|
kwargs.pop("load_in_4bit", False) |
|
|
|
|
kwargs.pop("quantization_config", None) |
|
|
|
|
kwargs.pop("adapter_kwargs", {}) |
|
|
|
|
kwargs.pop("adapter_name", "default") |
|
|
|
|
kwargs.pop("use_flash_attention_2", False) |
|
|
|
|
|
|
|
|
|
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) |
|
|
|
|
|
|
|
|
|
if len(kwargs) > 0: |
|
|
|
@ -108,6 +127,10 @@ def new_from_pretrained(
|
|
|
|
|
**kwargs, |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|
config = copy.deepcopy(config) |
|
|
|
|
kwarg_attn_imp = kwargs.pop("attn_implementation", None) |
|
|
|
|
if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: |
|
|
|
|
config._attn_implementation = kwarg_attn_imp |
|
|
|
|
model_kwargs = kwargs |
|
|
|
|
|
|
|
|
|
if commit_hash is None: |
|
|
|
|