|
|
|
@ -16,7 +16,6 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
|
|
|
|
from colossalai.shardformer.shard import ShardConfig
|
|
|
|
|
|
|
|
|
|
from ..layer import cross_entropy_1d
|
|
|
|
|
from ..layer._operation import gather_forward_split_backward
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
|
|
|
|
@ -279,7 +278,7 @@ class LlamaPipelineForwards:
|
|
|
|
|
shift_labels = shift_labels.view(-1)
|
|
|
|
|
# Enable model parallelism
|
|
|
|
|
shift_labels = shift_labels.to(shift_logits.device)
|
|
|
|
|
if shard_config.enable_tensor_parallelism:
|
|
|
|
|
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
|
|
|
|
|
new_vocab_size = logits.shape[-1]
|
|
|
|
|
shift_logits = shift_logits.view(-1, new_vocab_size)
|
|
|
|
|
loss = cross_entropy_1d(
|
|
|
|
@ -289,9 +288,6 @@ class LlamaPipelineForwards:
|
|
|
|
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
|
|
|
|
loss = loss_fct(shift_logits, shift_labels)
|
|
|
|
|
|
|
|
|
|
if not shard_config.parallel_output:
|
|
|
|
|
logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)
|
|
|
|
|
|
|
|
|
|
if not return_dict:
|
|
|
|
|
output = (logits,) + outputs[1:]
|
|
|
|
|
return (loss,) + output if loss is not None else output
|
|
|
|
@ -578,23 +574,15 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|
|
|
|
# Shift so that tokens < n predict n
|
|
|
|
|
shift_logits = logits[..., :-1, :].contiguous()
|
|
|
|
|
shift_labels = labels[..., 1:].contiguous()
|
|
|
|
|
# Flatten the tokens
|
|
|
|
|
loss_fct = CrossEntropyLoss()
|
|
|
|
|
shift_labels = shift_labels.view(-1)
|
|
|
|
|
# Enable model parallelism
|
|
|
|
|
shift_labels = shift_labels.to(shift_logits.device)
|
|
|
|
|
if shard_config.enable_tensor_parallelism:
|
|
|
|
|
|
|
|
|
|
new_vocab_size = logits.shape[-1]
|
|
|
|
|
shift_logits = shift_logits.view(-1, new_vocab_size)
|
|
|
|
|
loss = cross_entropy_1d(
|
|
|
|
|
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
|
|
|
|
loss = loss_fct(shift_logits, shift_labels)
|
|
|
|
|
|
|
|
|
|
if not shard_config.parallel_output:
|
|
|
|
|
logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)
|
|
|
|
|
|
|
|
|
|
if not return_dict:
|
|
|
|
|
output = (logits,) + outputs[1:]
|
|
|
|
|