mirror of https://github.com/hpcaitech/ColossalAI
from typing import Any, Dict, List, Type

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import BertForMaskedLM
from transformers.models.bert.modeling_bert import MaskedLMOutput

from ..layer.dist_crossentropy import applyDistCrossEntropy


class BertForMaskedLM_(BertForMaskedLM):
    """Variant of BertForMaskedLM whose forward computes the masked-LM loss
    with applyDistCrossEntropy instead of the stock CrossEntropyLoss."""

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        # print("[Inject OK] Injected forward method")
        # Fall back to the config default when return_dict is not given.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Run the BERT encoder with the standard HuggingFace arguments.
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # outputs[0] is the last hidden state; self.cls is the MLM head that
        # projects it to vocabulary logits.
        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None

        # Compute the masked-LM loss with the distributed cross-entropy
        # instead of the stock CrossEntropyLoss (kept below for reference).
        if labels is not None:
            masked_lm_loss = applyDistCrossEntropy(prediction_scores, labels)
        # if labels is not None:
        #     loss_fct = CrossEntropyLoss()  # -100 index = padding token
        #     masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        # Mirror the HuggingFace tuple-style return when return_dict is False.
        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
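

# --------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module; the checkpoint name
# and single-process setup are illustrative assumptions). The class is a
# drop-in replacement for BertForMaskedLM: loading a checkpoint and running a
# forward pass works unchanged. Passing `labels` routes the loss through
# applyDistCrossEntropy, which presumably expects the distributed environment
# ColossalAI sets up, so the sketch only requests logits.
if __name__ == "__main__":
    model = BertForMaskedLM_.from_pretrained("bert-base-uncased")
    model.eval()

    dummy_input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
    with torch.no_grad():
        out = model(input_ids=dummy_input_ids)
    # Logits keep the usual HuggingFace shape: (batch, seq_len, vocab_size).
    print(out.logits.shape)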