[Shardformer] Add parallel output for shardformer models(bloom, falcon) (#5702)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* add parallel cross entropy output for falcon model & fix some typos in bloom.py

* fix module name error, self.model -> self.transformers in bloom, falcon model

* Fix the overflow bug of distributed cross entropy loss function when training with fp16

* add dtype to parallel cross entropy loss function

* fix dtype related typos adn prettify the loss.py

* fix grad dtype and update dtype mismatch error

* fix typo bugs
pull/5746/head
Haze188 6 months ago committed by GitHub
parent 9d83c6d715
commit 22ce873c3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -22,6 +22,7 @@ class DistCrossEntropy(Function):
ignore_index: int,
process_group: ProcessGroup,
vocab_size: int,
dtype=torch.float32,
):
r"""
Calculate the cross entropy loss before gather, the origin loss function is as follows:
@ -34,7 +35,7 @@ class DistCrossEntropy(Function):
Args:
vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
[batch_size, seq_len, vocab_size]
labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is
target (:class:`torch.Tensor`): The labels of the vocabulary, shape is
[batch_size, seq_len]
Returns:
@ -86,7 +87,7 @@ class DistCrossEntropy(Function):
dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group)
exp_logits = vocab_logits
torch.exp(vocab_logits, out=exp_logits)
sum_exp_logits = torch.sum(exp_logits, dim=-1)
sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)
# calculate the loss
@ -97,9 +98,10 @@ class DistCrossEntropy(Function):
loss = torch.sum(loss).div_(num_non_zero)
# calculate the softmax
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype)
exp_logits[target == ignore_index] = 0.0
ctx.save_for_backward(exp_logits, mask, masked_target_1d)
ctx.dtype = dtype
return loss
@ -114,11 +116,11 @@ class DistCrossEntropy(Function):
partion_vocab_size = grad_logits.shape[-1]
grad_logits_2d = grad_logits.view(-1, partion_vocab_size)
update = 1.0 - mask.view(-1).float()
update = 1.0 - mask.view(-1).float().to(ctx.dtype)
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
grad_logits.mul_(grad_output.unsqueeze(dim=-1))
return grad_logits, None, None, None, None
return grad_logits, None, None, None, None, None
def cross_entropy_1d(
@ -127,5 +129,6 @@ def cross_entropy_1d(
ignore_index: int = -100,
process_group: ProcessGroup = None,
vocab_size: int = None,
dtype: torch.dtype = None,
) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size)
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)

@ -10,6 +10,7 @@ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_m
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
@ -27,6 +28,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
logger = logging.get_logger(__name__)
@ -354,7 +357,7 @@ class BloomPipelineForwards:
past_key_values = None
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
lm_logits = self.lm_head(hidden_states).contiguous()
loss = None
if labels is not None:
@ -365,10 +368,21 @@ class BloomPipelineForwards:
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
new_vocab_size = lm_logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
else:
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels.view(-1))
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
@ -1065,3 +1079,79 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
)
return forward
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
from transformers import BloomForCausalLM
def forward(
self: BloomForCausalLM,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
past_key_values = None
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
new_vocab_size = lm_logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
return forward

@ -14,6 +14,7 @@ from transformers.modeling_attn_mask_utils import (
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
@ -31,6 +32,8 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
def build_falcon_alibi_tensor(
@ -437,14 +440,28 @@ class FalconPipelineForwards:
loss = None
if labels is not None:
# Shift so that tokens < n predict n
labels = labels.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
new_vocab_size = shift_logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
else:
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size),
shift_labels.view(batch_size * seq_length),
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
@ -747,3 +764,79 @@ class FalconPipelineForwards:
else:
hidden_states = outputs.get("hidden_states")
return {"hidden_states": hidden_states}
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
from transformers import FalconForCausalLM
def forward(
self: FalconForCausalLM,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
past_key_values = None
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
labels = labels.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
new_vocab_size = shift_logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
return forward

@ -389,6 +389,7 @@ class GPT2PipelineForwards:
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
else:
loss = loss_fct(shift_logits, shift_labels)
@ -1294,6 +1295,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
if not return_dict:

@ -332,6 +332,7 @@ class LlamaPipelineForwards:
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
@ -768,6 +769,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)
if not return_dict:

@ -281,6 +281,7 @@ class MistralForwards:
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
@ -701,6 +702,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)
if not return_dict:

@ -348,6 +348,7 @@ class OPTPipelineForwards:
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.decoder.dtype,
)
else:
loss_fct = CrossEntropyLoss()
@ -988,6 +989,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.decoder.dtype,
)
if not return_dict:

@ -16,6 +16,7 @@ from ..modeling.bloom import (
get_jit_fused_bloom_attention_forward,
get_jit_fused_bloom_gelu_forward,
get_jit_fused_bloom_mlp_forward,
get_lm_forward_with_dist_cross_entropy,
)
from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@ -287,12 +288,18 @@ class BloomForCausalLMPolicy(BloomPolicy):
suffix="lm_head",
target_module=col_nn.VocabParallelLMHead1D,
kwargs=dict(
gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
gather_output=not self.shard_config.parallel_output,
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
),
),
policy=policy,
target_key=BloomForCausalLM,
)
if self.shard_config.parallel_output:
method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=BloomForCausalLM
)
else:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(

@ -7,7 +7,12 @@ from torch.nn import Module
import colossalai.shardformer.layer as col_nn
from ..modeling.falcon import FalconPipelineForwards, build_falcon_alibi_tensor_fn, get_tp_falcon_decoder_layer_forward
from ..modeling.falcon import (
FalconPipelineForwards,
build_falcon_alibi_tensor_fn,
get_lm_forward_with_dist_cross_entropy,
get_tp_falcon_decoder_layer_forward,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["FalconPolicy"]
@ -233,12 +238,19 @@ class FalconForCausalLMPolicy(FalconPolicy):
suffix="lm_head",
target_module=col_nn.VocabParallelLMHead1D,
kwargs=dict(
gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
gather_output=not self.shard_config.parallel_output,
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
),
),
policy=policy,
target_key=FalconForCausalLM,
)
if self.shard_config.parallel_output:
method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=FalconForCausalLM
)
else:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(

Loading…
Cancel
Save