mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt] fix lora support for gpt (#3113)
* fix gpt-actor * fix gpt-critic * fix opt-criticpull/3116/head
parent
0aa92c0409
commit
0672b5afac
|
@ -14,12 +14,16 @@ class GPTActor(Actor):
|
||||||
pretrained (str): Pretrained model name or path.
|
pretrained (str): Pretrained model name or path.
|
||||||
config (GPT2Config): Model config.
|
config (GPT2Config): Model config.
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
checkpoint (bool): Enable gradient checkpointing.
|
||||||
|
lora_rank (int): Rank of the LoRa layer.
|
||||||
|
lora_train_bias (str): Bias training strategy for the LoRa layer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
pretrained: Optional[str] = None,
|
pretrained: Optional[str] = None,
|
||||||
config: Optional[GPT2Config] = None,
|
config: Optional[GPT2Config] = None,
|
||||||
checkpoint: bool = False) -> None:
|
checkpoint: bool = False,
|
||||||
|
lora_rank: int = 0,
|
||||||
|
lora_train_bias: str = 'none') -> None:
|
||||||
if pretrained is not None:
|
if pretrained is not None:
|
||||||
model = GPT2LMHeadModel.from_pretrained(pretrained)
|
model = GPT2LMHeadModel.from_pretrained(pretrained)
|
||||||
elif config is not None:
|
elif config is not None:
|
||||||
|
@ -28,4 +32,4 @@ class GPTActor(Actor):
|
||||||
model = GPT2LMHeadModel(GPT2Config())
|
model = GPT2LMHeadModel(GPT2Config())
|
||||||
if checkpoint:
|
if checkpoint:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
super().__init__(model)
|
super().__init__(model, lora_rank, lora_train_bias)
|
||||||
|
|
|
@ -15,13 +15,16 @@ class GPTCritic(Critic):
|
||||||
pretrained (str): Pretrained model name or path.
|
pretrained (str): Pretrained model name or path.
|
||||||
config (GPT2Config): Model config.
|
config (GPT2Config): Model config.
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
checkpoint (bool): Enable gradient checkpointing.
|
||||||
|
lora_rank (int): Rank of the LO-RA decomposition.
|
||||||
|
lora_train_bias (str): LoRA bias training mode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
pretrained: Optional[str] = None,
|
pretrained: Optional[str] = None,
|
||||||
config: Optional[GPT2Config] = None,
|
config: Optional[GPT2Config] = None,
|
||||||
checkpoint: bool = False,
|
checkpoint: bool = False,
|
||||||
**kwargs) -> None:
|
lora_rank: int = 0,
|
||||||
|
lora_train_bias: str = 'none') -> None:
|
||||||
if pretrained is not None:
|
if pretrained is not None:
|
||||||
model = GPT2Model.from_pretrained(pretrained)
|
model = GPT2Model.from_pretrained(pretrained)
|
||||||
elif config is not None:
|
elif config is not None:
|
||||||
|
@ -31,4 +34,4 @@ class GPTCritic(Critic):
|
||||||
if checkpoint:
|
if checkpoint:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
value_head = nn.Linear(model.config.n_embd, 1)
|
value_head = nn.Linear(model.config.n_embd, 1)
|
||||||
super().__init__(model, value_head, **kwargs)
|
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
||||||
|
|
|
@ -34,5 +34,5 @@ class OPTCritic(Critic):
|
||||||
model = OPTModel(OPTConfig())
|
model = OPTModel(OPTConfig())
|
||||||
if checkpoint:
|
if checkpoint:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
|
||||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
||||||
|
|
Loading…
Reference in New Issue