mirror of https://github.com/hpcaitech/ColossalAI
[coati] Fix LlamaCritic (#3475)
* mv LlamaForCausalLM to LlamaModel
* rm unused imports
---------
Co-authored-by: gongenlei <gongenlei@baidu.com>
pull/3497/head
parent
8f2c55f9c9
commit
a7ca297281
|
@@ -1,8 +1,7 @@
 from typing import Optional

 import torch
 import torch.nn as nn
-from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
+from transformers import LlamaConfig, LlamaModel

 from ..base import Critic

|
@@ -28,11 +27,11 @@ class LlamaCritic(Critic):
                  **kwargs) -> None:

         if pretrained is not None:
-            model = LlamaForCausalLM.from_pretrained(pretrained)
+            model = LlamaModel.from_pretrained(pretrained)
         elif config is not None:
-            model = LlamaForCausalLM(config)
+            model = LlamaModel(config)
         else:
-            model = LlamaForCausalLM(LlamaConfig())
+            model = LlamaModel(LlamaConfig())

         if checkpoint:
             model.gradient_checkpointing_enable()
|
Loading…
Reference in New Issue