From a7ca2972810ac784754f0f31e21324687c03b324 Mon Sep 17 00:00:00 2001
From: gongenlei
Date: Fri, 7 Apr 2023 11:39:09 +0800
Subject: [PATCH] [coati] Fix LlamaCritic (#3475)

* mv LlamaForCausalLM to LlamaModel

* rm unused imports

---------

Co-authored-by: gongenlei
---
 applications/Chat/coati/models/llama/llama_critic.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/applications/Chat/coati/models/llama/llama_critic.py b/applications/Chat/coati/models/llama/llama_critic.py
index cd565031e..dd9e5e7bf 100644
--- a/applications/Chat/coati/models/llama/llama_critic.py
+++ b/applications/Chat/coati/models/llama/llama_critic.py
@@ -1,8 +1,7 @@
 from typing import Optional
 
-import torch
 import torch.nn as nn
-from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
+from transformers import LlamaConfig, LlamaModel
 
 from ..base import Critic
 
@@ -28,11 +27,11 @@ class LlamaCritic(Critic):
                  **kwargs) -> None:
 
         if pretrained is not None:
-            model = LlamaForCausalLM.from_pretrained(pretrained)
+            model = LlamaModel.from_pretrained(pretrained)
         elif config is not None:
-            model = LlamaForCausalLM(config)
+            model = LlamaModel(config)
         else:
-            model = LlamaForCausalLM(LlamaConfig())
+            model = LlamaModel(LlamaConfig())
 
         if checkpoint:
             model.gradient_checkpointing_enable()
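
Why the swap works: a critic only needs per-token hidden states to feed a scalar
value head, whereas LlamaForCausalLM wraps the backbone with a language-modeling
head and returns vocabulary logits the critic has no use for. Below is a minimal
sketch of the pattern, using a tiny random-weight config so it runs quickly and a
plain nn.Linear as a hypothetical stand-in for whatever value head the ..base.Critic
class (not shown in this patch) actually provides:

    import torch
    import torch.nn as nn
    from transformers import LlamaConfig, LlamaModel

    # Tiny config so the sketch runs without downloading weights; real use
    # would call LlamaModel.from_pretrained(...) as in the patched code.
    config = LlamaConfig(hidden_size=64, intermediate_size=128,
                         num_hidden_layers=2, num_attention_heads=4,
                         vocab_size=1000)
    backbone = LlamaModel(config)
    value_head = nn.Linear(config.hidden_size, 1)  # hypothetical stand-in

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    # LlamaModel returns hidden states: (batch, seq_len, hidden_size)
    hidden = backbone(input_ids).last_hidden_state
    # Project each token's hidden state to a scalar value: (batch, seq_len)
    values = value_head(hidden).squeeze(-1)

With LlamaForCausalLM the forward pass would instead yield logits of shape
(batch, seq_len, vocab_size), which is why moving to LlamaModel (and dropping the
now-unused torch and AutoModelForCausalLM imports) is the right fix here.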