Merge branch 'main' of github.com:THUDM/ChatGLM-6B

pull/621/head
duzx16 2 years ago
commit c6294ab3fd

@ -11,6 +11,8 @@
* [ChatGLM-Finetuning](https://github.com/liucongg/ChatGLM-Finetuning)基于ChatGLM-6B模型进行下游具体任务微调涉及Freeze、Lora、P-tuning等并进行实验效果对比。 * [ChatGLM-Finetuning](https://github.com/liucongg/ChatGLM-Finetuning)基于ChatGLM-6B模型进行下游具体任务微调涉及Freeze、Lora、P-tuning等并进行实验效果对比。
* [InstructGLM](https://github.com/yanqiangmiffy/InstructGLM)基于ChatGLM-6B进行指令学习汇总开源中英文指令数据基于Lora进行指令数据微调开放了Alpaca、Belle微调后的Lora权重修复web_demo重复问题 * [InstructGLM](https://github.com/yanqiangmiffy/InstructGLM)基于ChatGLM-6B进行指令学习汇总开源中英文指令数据基于Lora进行指令数据微调开放了Alpaca、Belle微调后的Lora权重修复web_demo重复问题
* [ChatGLM-web](https://github.com/NCZkevin/chatglm-web)基于FastAPI和Vue3搭建的ChatGLM演示网站(支持chatglm流式输出、前端调整模型参数、上下文选择、保存图片、知识库问答等功能) * [ChatGLM-web](https://github.com/NCZkevin/chatglm-web)基于FastAPI和Vue3搭建的ChatGLM演示网站(支持chatglm流式输出、前端调整模型参数、上下文选择、保存图片、知识库问答等功能)
* [glm-bot](https://github.com/initialencounter/glm-bot)将ChatGLM接入Koishi可在各大聊天平台上调用ChatGLM
以下是部分针对本项目的教程/文档: 以下是部分针对本项目的教程/文档:
* [Windows部署文档](https://github.com/ZhangErling/ChatGLM-6B/blob/main/deployment_windows.md) * [Windows部署文档](https://github.com/ZhangErling/ChatGLM-6B/blob/main/deployment_windows.md)
* [ChatGLM-6B 的部署与微调教程 @ModelWhale平台](https://www.heywhale.com/mw/project/6436d82948f7da1fee2be59e)

@ -354,6 +354,7 @@ def main():
tokenizer=tokenizer, tokenizer=tokenizer,
data_collator=data_collator, data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None, compute_metrics=compute_metrics if training_args.predict_with_generate else None,
save_prefixencoder=model_args.pre_seq_len is not None
) )
# Training # Training

@ -317,7 +317,9 @@ class Trainer:
callbacks: Optional[List[TrainerCallback]] = None, callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
save_prefixencoder: bool = False,
): ):
self.save_prefixencoder = save_prefixencoder
if args is None: if args is None:
output_dir = "tmp_trainer" output_dir = "tmp_trainer"
logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.") logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
@ -2825,12 +2827,17 @@ class Trainer:
state_dict = self.model.state_dict() state_dict = self.model.state_dict()
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else: else:
if self.save_prefixencoder:
print("Saving PrefixEncoder")
state_dict = self.model.state_dict() state_dict = self.model.state_dict()
filtered_state_dict = {} filtered_state_dict = {}
for k, v in self.model.named_parameters(): for k, v in self.model.named_parameters():
if v.requires_grad: if v.requires_grad:
filtered_state_dict[k] = state_dict[k] filtered_state_dict[k] = state_dict[k]
self.model.save_pretrained(output_dir, state_dict=filtered_state_dict) self.model.save_pretrained(output_dir, state_dict=filtered_state_dict)
else:
print("Saving the whole model")
self.model.save_pretrained(output_dir, state_dict=state_dict)
if self.tokenizer is not None: if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir) self.tokenizer.save_pretrained(output_dir)

Loading…
Cancel
Save