@@ -21,7 +21,9 @@ from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.shard import ShardConfig

+from ..layer import cross_entropy_1d
+
 logger = logging.get_logger(__name__)
@@ -351,7 +353,7 @@ class OPTPipelineForwards:
                 # Flatten the tokens
                 loss_fct = CrossEntropyLoss()
-                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

             if not return_dict:
                 output = (logits,) + outputs[1:]
                 return (loss,) + output if loss is not None else output
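
# A minimal runnable sketch of the shifted cross-entropy pattern the hunk
# above restores: token t's logits predict token t+1, so logits drop the
# last position and labels drop the first before flattening. Shapes and the
# OPT vocab size 50272 are illustrative, not taken from the PR.
import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(2, 8, 50272)           # (batch, seq, vocab); dummy values
labels = torch.randint(0, 50272, (2, 8))    # (batch, seq)

shift_logits = logits[..., :-1, :].contiguous()  # predictions for positions 1..seq-1
shift_labels = labels[..., 1:].contiguous()      # targets are the next tokens

loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, 50272), shift_labels.view(-1))
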
@@ -987,8 +989,8 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
                     process_group=shard_config.tensor_parallel_process_group,
                     vocab_size=self.lm_head.out_features,
                 )
-                #loss_fct = CrossEntropyLoss()
-                #loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+                # loss_fct = CrossEntropyLoss()
+                # loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

         if not return_dict:
             output = (logits,) + outputs[1:]
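
# Why cross_entropy_1d takes vocab_size=self.lm_head.out_features and a
# process_group: under tensor parallelism each rank holds only vocab/tp
# columns of the logits, so the routine needs the full vocabulary size and
# the TP group to recombine the softmax denominator. A single-process sketch
# of that math (illustrative only, not ColossalAI's implementation):
import torch
import torch.nn.functional as F

full_vocab, tp_size = 8, 2
logits = torch.randn(4, full_vocab)            # pretend full-vocab logits
labels = torch.randint(0, full_vocab, (4,))

dense = F.cross_entropy(logits, labels)        # ordinary single-device loss

# Each TP rank would hold one contiguous vocab slice; the softmax
# denominator recombines exactly via log-sum-exp over the shards,
# so no rank ever needs the full-vocab logits in memory.
shards = logits.chunk(tp_size, dim=-1)
lse = torch.logsumexp(torch.stack([s.logsumexp(dim=-1) for s in shards]), dim=0)
target = logits.gather(-1, labels[:, None]).squeeze(-1)
sharded = (lse - target).mean()

assert torch.allclose(dense, sharded, atol=1e-6)
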
@@ -1002,4 +1004,4 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             attentions=outputs.attentions,
         )

-        return forward
+    return forward
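
# The hunk above dedents `return forward` from the inner function's body to
# the factory's; a minimal stand-in showing why the indent level matters
# (make_forward here is illustrative, not ColossalAI's code):
def make_forward(shard_config):
    def forward(hidden_states):
        return hidden_states    # stand-in for the real sharded forward
    return forward              # factory level: hands the closure to the policy
    # One level deeper, this line would sit inside `forward` itself,
    # unreachable after its own return, and make_forward would return None.

fwd = make_forward(shard_config=None)
assert fwd(42) == 42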