mirror of https://github.com/hpcaitech/ColossalAI
[coati] Fix LlamaCritic (#3475)
* mv LlamaForCausalLM to LlamaModel
* rm unused imports

---------

Co-authored-by: gongenlei <gongenlei@baidu.com>
parent 8f2c55f9c9
commit a7ca297281
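Why the swap matters: a critic produces scalar values from the transformer's hidden states, and the base LlamaModel returns last_hidden_state directly, whereas LlamaForCausalLM returns vocabulary logits by default. Below is a minimal sketch of that difference, assuming the Hugging Face transformers API; the tiny LlamaConfig and the Linear value head are illustrative assumptions, not coati's actual code.

import torch
import torch.nn as nn
from transformers import LlamaConfig, LlamaModel

# Tiny illustrative config so the sketch runs quickly (values are arbitrary).
config = LlamaConfig(vocab_size=1000, hidden_size=64, intermediate_size=128,
                     num_hidden_layers=2, num_attention_heads=4)
model = LlamaModel(config)  # base transformer, no language-modeling head

input_ids = torch.randint(0, 1000, (1, 8))
outputs = model(input_ids)

# LlamaModel exposes last_hidden_state, which a value head can consume;
# LlamaForCausalLM would instead return logits of shape (1, 8, vocab_size).
hidden = outputs.last_hidden_state        # shape (1, 8, 64)
value_head = nn.Linear(64, 1)             # hypothetical scalar value head
values = value_head(hidden).squeeze(-1)   # shape (1, 8), one value per token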
@@ -1,8 +1,7 @@
 from typing import Optional
 
-import torch
 import torch.nn as nn
-from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
+from transformers import LlamaConfig, LlamaModel
 
 from ..base import Critic
 
@@ -28,11 +27,11 @@ class LlamaCritic(Critic):
                  **kwargs) -> None:
 
         if pretrained is not None:
-            model = LlamaForCausalLM.from_pretrained(pretrained)
+            model = LlamaModel.from_pretrained(pretrained)
         elif config is not None:
-            model = LlamaForCausalLM(config)
+            model = LlamaModel(config)
         else:
-            model = LlamaForCausalLM(LlamaConfig())
+            model = LlamaModel(LlamaConfig())
 
         if checkpoint:
             model.gradient_checkpointing_enable()
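For reference, a hedged usage sketch of the three constructor paths touched by this patch. The parameter names pretrained, config, and checkpoint are read off the diff above; the import path coati.models.llama is an assumption about the package layout, and the small config values are arbitrary.

from transformers import LlamaConfig
from coati.models.llama import LlamaCritic  # assumed import path

# No pretrained weights and no config: falls back to LlamaModel(LlamaConfig())
critic = LlamaCritic()

# Explicit config, with gradient checkpointing to trade compute for memory
small_cfg = LlamaConfig(vocab_size=1000, hidden_size=64, intermediate_size=128,
                        num_hidden_layers=2, num_attention_heads=4)
critic = LlamaCritic(config=small_cfg, checkpoint=True)

# Load pretrained weights by name or local path:
# critic = LlamaCritic(pretrained='path/to/llama')  # hypothetical checkpoint location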