[coati] Fix LlamaCritic (#3475)

* mv LlamaForCausalLM to LlamaModel

* rm unused imports

---------

Co-authored-by: gongenlei <gongenlei@baidu.com>
pull/3497/head
gongenlei 2023-04-07 11:39:09 +08:00 committed by GitHub
parent 8f2c55f9c9
commit a7ca297281
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 4 additions and 5 deletions

View File

@@ -1,8 +1,7 @@
from typing import Optional
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
from transformers import LlamaConfig, LlamaModel
from ..base import Critic
@@ -28,11 +27,11 @@ class LlamaCritic(Critic):
**kwargs) -> None:
if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained)
model = LlamaModel.from_pretrained(pretrained)
elif config is not None:
model = LlamaForCausalLM(config)
model = LlamaModel(config)
else:
model = LlamaForCausalLM(LlamaConfig())
model = LlamaModel(LlamaConfig())
if checkpoint:
model.gradient_checkpointing_enable()