diff --git a/ptuning/README.md b/ptuning/README.md index f92a328..30c15d2 100644 --- a/ptuning/README.md +++ b/ptuning/README.md @@ -8,7 +8,7 @@ ## 软件依赖 运行微调需要4.27.1版本的`transformers`。除 ChatGLM-6B 的依赖之外,还需要安装以下依赖 ``` -pip install rouge_chinese nltk jieba datasets +pip install rouge_chinese nltk jieba datasets filelock ``` ## 使用方法 diff --git a/ptuning/main.py b/ptuning/main.py index 49b08b0..f4508af 100644 --- a/ptuning/main.py +++ b/ptuning/main.py @@ -45,6 +45,8 @@ from trainer_seq2seq import Seq2SeqTrainer from arguments import ModelArguments, DataTrainingArguments +from filelock import FileLock + logger = logging.getLogger(__name__) def main(): @@ -122,7 +124,8 @@ def main(): new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) else: - model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) + with FileLock("model.lock"): + model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) if model_args.quantization_bit is not None: print(f"Quantized to {model_args.quantization_bit} bit")