[chat]fix sft training for bloom, gpt and opt (#3418)

fix sft training for bloom, gpt and opt
pull/3424/head
Yuanchen 2023-04-04 09:46:23 +08:00 committed by GitHub
parent 638a07a7f9
commit b09adff724
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 0 deletions

View File

@@ -33,3 +33,6 @@ class BLOOMLM(LM):
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
    """Run a forward pass through the wrapped causal LM.

    Delegates directly to ``self.model`` so that ``labels`` reaches the
    underlying HuggingFace model (needed for the LM loss during SFT).

    Args:
        input_ids: token id tensor passed straight to the model.
        attention_mask: optional attention mask, forwarded as a keyword.
        labels: optional target ids for loss computation, forwarded as a keyword.
        **kwargs: any extra keyword arguments accepted by the model.

    Returns:
        Whatever ``self.model`` returns (typically a model output object).
    """
    outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
    return outputs

View File

@@ -33,3 +33,6 @@ class GPTLM(LM):
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
    """Forward the inputs to the wrapped causal LM.

    Passes ``labels`` through to the underlying HuggingFace model so the
    model can compute the language-modeling loss during SFT.

    Args:
        input_ids: token id tensor handed straight to the model.
        attention_mask: optional attention mask, forwarded as a keyword.
        labels: optional target ids for loss computation, forwarded as a keyword.
        **kwargs: additional keyword arguments accepted by the model.

    Returns:
        The wrapped model's output, unchanged.
    """
    result = self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
    return result

View File

@@ -33,3 +33,6 @@ class OPTLM(LM):
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
    """Delegate a forward pass to the wrapped causal LM.

    Forwarding ``labels`` lets the underlying HuggingFace model compute
    the language-modeling loss needed for SFT training.

    Args:
        input_ids: token id tensor given directly to the model.
        attention_mask: optional attention mask, forwarded as a keyword.
        labels: optional target ids for loss computation, forwarded as a keyword.
        **kwargs: any further keyword arguments the model accepts.

    Returns:
        The model's output object, returned as-is.
    """
    model_out = self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
    return model_out