Remove unnecessary duplicated model loading.

Both branches of the model_args.ptuning_checkpoint check contained an identical AutoModel.from_pretrained call; this change hoists that call above the branch so the base model is loaded in exactly one place, and the checkpoint branch only overlays the prefix-encoder weights.

pull/1359/head
Guoqiang QI 2023-08-01 20:42:41 +08:00 committed by GitHub
parent d835c4b001
commit dfca8661ea
1 changed file with 3 additions and 4 deletions

@@ -136,17 +136,16 @@ def main():
     config.pre_seq_len = model_args.pre_seq_len
     config.prefix_projection = model_args.prefix_projection
+    model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
     if model_args.ptuning_checkpoint is not None:
         print(f"Loading prefix_encoder weight from {model_args.ptuning_checkpoint}")
-        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
         prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
         new_prefix_state_dict = {}
         for k, v in prefix_state_dict.items():
             if k.startswith("transformer.prefix_encoder."):
                 new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
         model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
-    else:
-        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
     if model_args.quantization_bit is not None:
         print(f"Quantized to {model_args.quantization_bit} bit")