diff --git a/applications/Chat/coati/models/llama/llama_critic.py b/applications/Chat/coati/models/llama/llama_critic.py
index cd565031e..dd9e5e7bf 100644
--- a/applications/Chat/coati/models/llama/llama_critic.py
+++ b/applications/Chat/coati/models/llama/llama_critic.py
@@ -1,8 +1,7 @@
 from typing import Optional
 
-import torch
 import torch.nn as nn
-from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
+from transformers import LlamaConfig, LlamaModel
 
 from ..base import Critic
 
@@ -28,11 +27,11 @@ class LlamaCritic(Critic):
                  **kwargs) -> None:
 
         if pretrained is not None:
-            model = LlamaForCausalLM.from_pretrained(pretrained)
+            model = LlamaModel.from_pretrained(pretrained)
         elif config is not None:
-            model = LlamaForCausalLM(config)
+            model = LlamaModel(config)
         else:
-            model = LlamaForCausalLM(LlamaConfig())
+            model = LlamaModel(LlamaConfig())
 
         if checkpoint:
             model.gradient_checkpointing_enable()
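
Context for the swap above: a critic scores sequences with a scalar value head over the backbone's hidden states, so it needs LlamaModel (which returns last_hidden_state) rather than LlamaForCausalLM (which adds an LM head and returns vocabulary logits the critic never uses). The repo's actual head and pooling logic live in the Critic base class imported from ..base; the sketch below is a minimal, self-contained illustration of that pattern, not the project's implementation. The ValueCritic class name, the tiny config values, and the last-token pooling are assumptions made for the example.

# Hypothetical sketch: how a critic consumes LlamaModel's hidden states.
# Not the coati Critic base class; names and pooling are illustrative.
import torch
import torch.nn as nn
from transformers import LlamaConfig, LlamaModel


class ValueCritic(nn.Module):
    """Backbone hidden states -> one scalar value per sequence."""

    def __init__(self, model: nn.Module, hidden_size: int) -> None:
        super().__init__()
        self.model = model
        self.value_head = nn.Linear(hidden_size, 1)

    def forward(self, input_ids: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
        outputs = self.model(input_ids, attention_mask=attention_mask)
        # LlamaModel returns last_hidden_state of shape (batch, seq_len, hidden).
        hidden = outputs.last_hidden_state
        values = self.value_head(hidden).squeeze(-1)   # (batch, seq_len)
        return values[:, -1]                           # score the final token: (batch,)


# Usage with a deliberately tiny config so the example runs quickly on CPU.
config = LlamaConfig(hidden_size=64, intermediate_size=128,
                     num_hidden_layers=2, num_attention_heads=4,
                     vocab_size=1000)
critic = ValueCritic(LlamaModel(config), config.hidden_size)
ids = torch.randint(0, 1000, (2, 8))
mask = torch.ones_like(ids)
print(critic(ids, mask).shape)  # torch.Size([2])

Dropping the causal-LM wrapper also avoids allocating the (hidden_size x vocab_size) LM head that the critic would never train or use, and it explains why the now-unused import torch line could be removed from the module's imports.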