diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py
index 6d99efc19..a6d19edf5 100644
--- a/colossalai/shardformer/layer/loss.py
+++ b/colossalai/shardformer/layer/loss.py
@@ -22,6 +22,7 @@ class DistCrossEntropy(Function):
         ignore_index: int,
         process_group: ProcessGroup,
         vocab_size: int,
+        dtype=torch.float32,
     ):
         r"""
         Calculate the cross entropy loss before gather, the origin loss function is as follows:
@@ -34,7 +35,7 @@ class DistCrossEntropy(Function):
         Args:
             vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
               [batch_size, seq_len, vocab_size]
-            labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is
+            target (:class:`torch.Tensor`): The labels of the vocabulary, shape is
               [batch_size, seq_len]
 
         Returns:
@@ -86,7 +87,7 @@ class DistCrossEntropy(Function):
         dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group)
         exp_logits = vocab_logits
         torch.exp(vocab_logits, out=exp_logits)
-        sum_exp_logits = torch.sum(exp_logits, dim=-1)
+        sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)
         dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)
 
         # calculate the loss
@@ -97,9 +98,10 @@ class DistCrossEntropy(Function):
         loss = torch.sum(loss).div_(num_non_zero)
 
         # calculate the softmax
-        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+        exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype)
         exp_logits[target == ignore_index] = 0.0
         ctx.save_for_backward(exp_logits, mask, masked_target_1d)
+        ctx.dtype = dtype
 
         return loss
 
@@ -114,11 +116,11 @@ class DistCrossEntropy(Function):
         partion_vocab_size = grad_logits.shape[-1]
         grad_logits_2d = grad_logits.view(-1, partion_vocab_size)
 
-        update = 1.0 - mask.view(-1).float()
+        update = 1.0 - mask.view(-1).float().to(ctx.dtype)
         grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
 
         grad_logits.mul_(grad_output.unsqueeze(dim=-1))
-        return grad_logits, None, None, None, None
+        return grad_logits, None, None, None, None, None
 
 
 def cross_entropy_1d(
@@ -127,5 +129,6 @@ def cross_entropy_1d(
     ignore_index: int = -100,
     process_group: ProcessGroup = None,
     vocab_size: int = None,
+    dtype: torch.dtype = None,
 ) -> torch.Tensor:
-    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size)
+    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index c4f326364..bf74d0833 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -10,6 +10,7 @@ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_m
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
+    CausalLMOutputWithPast,
     QuestionAnsweringModelOutput,
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
@@ -27,6 +28,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
 
+from ..layer import cross_entropy_1d
+
 logger = logging.get_logger(__name__)
 
@@ -354,7 +357,7 @@ class BloomPipelineForwards:
         past_key_values = None
         if stage_manager.is_last_stage():
             hidden_states = transformer_outputs[0]
-            lm_logits = self.lm_head(hidden_states)
+            lm_logits = self.lm_head(hidden_states).contiguous()
 
             loss = None
             if labels is not None:
@@ -365,10 +368,21 @@ class BloomPipelineForwards:
                 shift_labels = labels[..., 1:].contiguous()
                 batch_size, seq_length, vocab_size = shift_logits.shape
                 # Flatten the tokens
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(
-                    shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
-                )
+                if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+                    new_vocab_size = lm_logits.shape[-1]
+                    shift_logits = shift_logits.view(-1, new_vocab_size)
+                    shift_labels = shift_labels.view(-1)
+                    loss = cross_entropy_1d(
+                        shift_logits,
+                        shift_labels,
+                        process_group=shard_config.tensor_parallel_process_group,
+                        vocab_size=self.lm_head.out_features,
+                        dtype=self.transformer.dtype,
+                    )
+                else:
+                    loss_fct = CrossEntropyLoss()
+                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                    loss = loss_fct(shift_logits, shift_labels.view(-1))
 
             if not return_dict:
                 output = (lm_logits,) + transformer_outputs[1:]
@@ -1065,3 +1079,79 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
         )
 
     return forward
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    from transformers import BloomForCausalLM
+
+    def forward(
+        self: BloomForCausalLM,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids=input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        past_key_values = None
+        hidden_states = transformer_outputs[0]
+        lm_logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(lm_logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            new_vocab_size = lm_logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            shift_labels = shift_labels.view(-1)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+                dtype=self.transformer.dtype,
+            )
+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+    return forward
diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py
index df3b09c71..a43bdf481 100644
--- a/colossalai/shardformer/modeling/falcon.py
+++ b/colossalai/shardformer/modeling/falcon.py
@@ -14,6 +14,7 @@ from transformers.modeling_attn_mask_utils import (
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
+    CausalLMOutputWithPast,
     QuestionAnsweringModelOutput,
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
@@ -31,6 +32,8 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 
+from ..layer import cross_entropy_1d
+
 
 def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
     def build_falcon_alibi_tensor(
@@ -437,14 +440,28 @@ class FalconPipelineForwards:
             loss = None
             if labels is not None:
                 # Shift so that tokens < n predict n
+                labels = labels.to(lm_logits.device)
                 shift_logits = lm_logits[..., :-1, :].contiguous()
                 shift_labels = labels[..., 1:].contiguous()
                 batch_size, seq_length, vocab_size = shift_logits.shape
                 # Flatten the tokens
                 loss_fct = CrossEntropyLoss()
-                loss = loss_fct(
-                    shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
-                )
+                if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+                    new_vocab_size = shift_logits.shape[-1]
+                    shift_logits = shift_logits.view(-1, new_vocab_size)
+                    shift_labels = shift_labels.view(-1)
+                    loss = cross_entropy_1d(
+                        shift_logits,
+                        shift_labels,
+                        process_group=shard_config.tensor_parallel_process_group,
+                        vocab_size=self.lm_head.out_features,
+                        dtype=self.transformer.dtype,
+                    )
+                else:
+                    loss = loss_fct(
+                        shift_logits.view(batch_size * seq_length, vocab_size),
+                        shift_labels.view(batch_size * seq_length),
+                    )
 
             if not return_dict:
                 output = (lm_logits,) + transformer_outputs[1:]
@@ -747,3 +764,79 @@ class FalconPipelineForwards:
         else:
             hidden_states = outputs.get("hidden_states")
             return {"hidden_states": hidden_states}
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    from transformers import FalconForCausalLM
+
+    def forward(
+        self: FalconForCausalLM,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        past_key_values = None
+        hidden_states = transformer_outputs[0]
+        lm_logits = self.lm_head(hidden_states)
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            labels = labels.to(lm_logits.device)
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            batch_size, seq_length, vocab_size = shift_logits.shape
+            # Flatten the tokens
+            new_vocab_size = shift_logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            shift_labels = shift_labels.view(-1)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+                dtype=self.transformer.dtype,
+            )
+
+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+    return forward
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index bfa995645..c49458dbd 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -389,6 +389,7 @@ class GPT2PipelineForwards:
                         shift_labels,
                         process_group=shard_config.tensor_parallel_process_group,
                         vocab_size=self.lm_head.out_features,
+                        dtype=self.transformer.dtype,
                     )
                 else:
                     loss = loss_fct(shift_logits, shift_labels)
@@ -1294,6 +1295,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
                 shift_labels,
                 process_group=shard_config.tensor_parallel_process_group,
                 vocab_size=self.lm_head.out_features,
+                dtype=self.transformer.dtype,
             )
 
         if not return_dict:
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 8a6a7cf17..d6f10ffaf 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -332,6 +332,7 @@ class LlamaPipelineForwards:
                         shift_labels,
                         process_group=shard_config.tensor_parallel_process_group,
                         vocab_size=self.lm_head.out_features,
+                        dtype=self.model.dtype,
                     )
                 else:
                     shift_logits = shift_logits.view(-1, self.config.vocab_size)
@@ -768,6 +769,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
                 shift_labels,
                 process_group=shard_config.tensor_parallel_process_group,
                 vocab_size=self.lm_head.out_features,
+                dtype=self.model.dtype,
             )
 
         if not return_dict:
diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py
index 93da71abb..5f96ebe3d 100644
--- a/colossalai/shardformer/modeling/mistral.py
+++ b/colossalai/shardformer/modeling/mistral.py
@@ -281,6 +281,7 @@ class MistralForwards:
                         shift_labels,
                         process_group=shard_config.tensor_parallel_process_group,
                         vocab_size=self.lm_head.out_features,
+                        dtype=self.model.dtype,
                     )
                 else:
                     shift_logits = shift_logits.view(-1, self.config.vocab_size)
@@ -701,6 +702,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
                 shift_labels,
                 process_group=shard_config.tensor_parallel_process_group,
                 vocab_size=self.lm_head.out_features,
+                dtype=self.model.dtype,
             )
 
         if not return_dict:
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index 5282e2eaa..f10860fef 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -348,6 +348,7 @@ class OPTPipelineForwards:
                         shift_labels,
                         process_group=shard_config.tensor_parallel_process_group,
                         vocab_size=self.lm_head.out_features,
+                        dtype=self.model.decoder.dtype,
                     )
                 else:
                     loss_fct = CrossEntropyLoss()
@@ -988,6 +989,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
                 shift_labels,
                 process_group=shard_config.tensor_parallel_process_group,
                 vocab_size=self.lm_head.out_features,
+                dtype=self.model.decoder.dtype,
             )
 
         if not return_dict:
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index 4f076d233..724a6b77c 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -16,6 +16,7 @@ from ..modeling.bloom import (
     get_jit_fused_bloom_attention_forward,
     get_jit_fused_bloom_gelu_forward,
     get_jit_fused_bloom_mlp_forward,
+    get_lm_forward_with_dist_cross_entropy,
 )
 from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -287,12 +288,18 @@ class BloomForCausalLMPolicy(BloomPolicy):
                     suffix="lm_head",
                     target_module=col_nn.VocabParallelLMHead1D,
                     kwargs=dict(
-                        gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
+                        gather_output=not self.shard_config.parallel_output,
+                        make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
                     ),
                 ),
                 policy=policy,
                 target_key=BloomForCausalLM,
             )
+            if self.shard_config.parallel_output:
+                method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
+                self.append_or_create_method_replacement(
+                    description=method_replacement, policy=policy, target_key=BloomForCausalLM
+                )
         else:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py
index 23d6efbeb..e5c167337 100644
--- a/colossalai/shardformer/policies/falcon.py
+++ b/colossalai/shardformer/policies/falcon.py
@@ -7,7 +7,12 @@ from torch.nn import Module
 
 import colossalai.shardformer.layer as col_nn
 
-from ..modeling.falcon import FalconPipelineForwards, build_falcon_alibi_tensor_fn, get_tp_falcon_decoder_layer_forward
+from ..modeling.falcon import (
+    FalconPipelineForwards,
+    build_falcon_alibi_tensor_fn,
+    get_lm_forward_with_dist_cross_entropy,
+    get_tp_falcon_decoder_layer_forward,
+)
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ["FalconPolicy"]
@@ -233,12 +238,19 @@ class FalconForCausalLMPolicy(FalconPolicy):
                     suffix="lm_head",
                     target_module=col_nn.VocabParallelLMHead1D,
                     kwargs=dict(
-                        gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
+                        gather_output=not self.shard_config.parallel_output,
+                        make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
                    ),
                 ),
                 policy=policy,
                 target_key=FalconForCausalLM,
             )
+            if self.shard_config.parallel_output:
+                method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
+                self.append_or_create_method_replacement(
+                    description=method_replacement, policy=policy, target_key=FalconForCausalLM
+                )
+
         else:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
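For context only (not part of the patch): below is a minimal single-process sketch of the arithmetic that `DistCrossEntropy` spreads across tensor-parallel ranks, showing why the patch accumulates `sum_exp_logits` in float32 and threads a `dtype` through `cross_entropy_1d`. The tensor shapes and variable names are illustrative assumptions, not code from the repository.

```python
# Illustrative sketch: the per-rank cross-entropy math of DistCrossEntropy,
# computed here on one process with the full (unsharded) logits.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 8, 32, dtype=torch.float16)  # [batch, seq, vocab], toy sizes
target = torch.randint(0, 32, (4, 8))

# Numerically stable log-sum-exp; the sum is accumulated in float32,
# mirroring the `dtype=torch.float32` added to torch.sum in loss.py.
logits_max = logits.max(dim=-1, keepdim=True).values
shifted = (logits - logits_max).float()
sum_exp = torch.exp(shifted).sum(dim=-1, dtype=torch.float32)
pred = shifted.gather(-1, target.unsqueeze(-1)).squeeze(-1)
loss = (torch.log(sum_exp) - pred).mean()

# Agrees with the dense reference up to fp16 rounding of the inputs.
ref = F.cross_entropy(logits.float().view(-1, 32), target.view(-1))
print(loss.item(), ref.item())
```

In the distributed version each rank holds only a vocabulary shard, so `sum_exp` and the target logit are completed with an `all_reduce` over the tensor-parallel process group before the same formula is applied; the `dtype` argument then casts the cached softmax back to the model's working precision for the backward pass.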