diff --git a/utils.py b/utils.py index 0ebc761..2bf13c2 100644 --- a/utils.py +++ b/utils.py @@ -1,7 +1,6 @@ import os from typing import Dict, Tuple, Union, Optional -from accelerate import load_checkpoint_and_dispatch from torch.nn import Module from transformers import AutoModel, AutoTokenizer from transformers.tokenization_utils import PreTrainedTokenizer @@ -40,6 +39,8 @@ def auto_configure_device_map(num_gpus: int) -> Dict[str, int]: def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2, multi_gpu_model_cache_dir: Union[str, os.PathLike] = "./temp_model_dir", tokenizer: Optional[PreTrainedTokenizer] = None, **kwargs) -> Module: + from accelerate import load_checkpoint_and_dispatch + model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs) model = model.eval()