From c31d9c0984818849ac3ab00b2641ba727642920b Mon Sep 17 00:00:00 2001
From: saber <3082548039@qq.com>
Date: Mon, 27 Mar 2023 22:11:57 +0800
Subject: [PATCH] Move import statement into function to avoid dependency

---
 utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/utils.py b/utils.py
index 0ebc761..2bf13c2 100644
--- a/utils.py
+++ b/utils.py
@@ -1,7 +1,6 @@
 import os
 from typing import Dict, Tuple, Union, Optional

-from accelerate import load_checkpoint_and_dispatch
 from torch.nn import Module
 from transformers import AutoModel, AutoTokenizer
 from transformers.tokenization_utils import PreTrainedTokenizer
@@ -40,6 +39,8 @@ def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
 def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
                        multi_gpu_model_cache_dir: Union[str, os.PathLike] = "./temp_model_dir",
                        tokenizer: Optional[PreTrainedTokenizer] = None, **kwargs) -> Module:
+    from accelerate import load_checkpoint_and_dispatch
+
     model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs)
     model = model.eval()

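
Note: with this patch, `accelerate` becomes an optional dependency. It is imported only when
`load_model_on_gpus` is actually called, so importing utils.py for single-GPU use no longer
fails on machines without it. Below is a minimal sketch of the same deferred-import pattern
with a simplified signature; the try/except guard and its error message are illustrative
additions, not part of the patch, and the rest of the function body is elided.

import os
from typing import Union

from torch.nn import Module
from transformers import AutoModel


def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike],
                       num_gpus: int = 2, **kwargs) -> Module:
    # Deferred import: `accelerate` is only required when this multi-GPU
    # entry point is actually used, not when the module is merely imported.
    try:
        from accelerate import load_checkpoint_and_dispatch  # noqa: F401
    except ImportError as exc:  # illustrative guard, not in the patch
        raise ImportError(
            "load_model_on_gpus requires the `accelerate` package; "
            "install it with `pip install accelerate`."
        ) from exc

    model = AutoModel.from_pretrained(checkpoint_path,
                                      trust_remote_code=True, **kwargs)
    model = model.eval()
    # ... device-map construction and load_checkpoint_and_dispatch(...)
    # continue here as in the original utils.py
    return model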