[coati] Fix LlamaCritic (#3475)

* mv LlamaForCausalLM to LlamaModel

* rm unused imports

---------

Co-authored-by: gongenlei <gongenlei@baidu.com>
pull/3497/head
gongenlei 2023-04-07 11:39:09 +08:00 committed by GitHub
parent 8f2c55f9c9
commit a7ca297281
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 5 deletions

View File

@ -1,8 +1,7 @@
from typing import Optional from typing import Optional
import torch
import torch.nn as nn import torch.nn as nn
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM from transformers import LlamaConfig, LlamaModel
from ..base import Critic from ..base import Critic
@ -28,11 +27,11 @@ class LlamaCritic(Critic):
**kwargs) -> None: **kwargs) -> None:
if pretrained is not None: if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained) model = LlamaModel.from_pretrained(pretrained)
elif config is not None: elif config is not None:
model = LlamaForCausalLM(config) model = LlamaModel(config)
else: else:
model = LlamaForCausalLM(LlamaConfig()) model = LlamaModel(LlamaConfig())
if checkpoint: if checkpoint:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()