mirror of https://github.com/hpcaitech/ColossalAI
[Shardformer] Add parallel output for shardformer models (bloom, falcon) (#5702)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* add parallel cross entropy output for falcon model & fix some typos in bloom.py
* fix module name error, self.model -> self.transformer in bloom, falcon model
* fix the overflow bug of the distributed cross entropy loss function when training with fp16
* add dtype to the parallel cross entropy loss function
* fix dtype related typos and prettify loss.py
* fix grad dtype and update dtype mismatch error
* fix typo bugs

pull/5746/head
parent 9d83c6d715
commit 22ce873c3f
@@ -22,6 +22,7 @@ class DistCrossEntropy(Function):
         ignore_index: int,
         process_group: ProcessGroup,
         vocab_size: int,
+        dtype=torch.float32,
     ):
         r"""
         Calculate the cross entropy loss before gather, the origin loss function is as follows:
@@ -34,7 +35,7 @@ class DistCrossEntropy(Function):
         Args:
             vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
               [batch_size, seq_len, vocab_size]
-            labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is
+            target (:class:`torch.Tensor`): The labels of the vocabulary, shape is
               [batch_size, seq_len]

         Returns:
@@ -86,7 +87,7 @@ class DistCrossEntropy(Function):
         dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group)
         exp_logits = vocab_logits
         torch.exp(vocab_logits, out=exp_logits)
-        sum_exp_logits = torch.sum(exp_logits, dim=-1)
+        sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)
         dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)

         # calculate the loss
@@ -97,9 +98,10 @@ class DistCrossEntropy(Function):
         loss = torch.sum(loss).div_(num_non_zero)

         # calculate the softmax
-        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+        exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype)
         exp_logits[target == ignore_index] = 0.0
         ctx.save_for_backward(exp_logits, mask, masked_target_1d)
+        ctx.dtype = dtype

         return loss

@@ -114,11 +116,11 @@ class DistCrossEntropy(Function):
         partion_vocab_size = grad_logits.shape[-1]
         grad_logits_2d = grad_logits.view(-1, partion_vocab_size)

-        update = 1.0 - mask.view(-1).float()
+        update = 1.0 - mask.view(-1).float().to(ctx.dtype)
         grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update

         grad_logits.mul_(grad_output.unsqueeze(dim=-1))
-        return grad_logits, None, None, None, None
+        return grad_logits, None, None, None, None, None


 def cross_entropy_1d(
@@ -127,5 +129,6 @@ def cross_entropy_1d(
     ignore_index: int = -100,
     process_group: ProcessGroup = None,
     vocab_size: int = None,
+    dtype: torch.dtype = None,
 ) -> torch.Tensor:
-    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size)
+    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
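Editor's note on the fp16 overflow these hunks address: after the usual max-shift that precedes the lines above, each exp() term is at most 1, but summing a large vocabulary shard's worth of them can exceed fp16's largest finite value (65504), so the normalizer silently becomes inf. A minimal standalone sketch with illustrative sizes (not part of the patch):

```python
import torch

# Each exp() term is <= 1 after the max-shift, yet a ~125k-wide vocab shard
# sums past fp16's max finite value (65504), turning the normalizer into inf.
exp_terms = torch.ones(2, 125_000, dtype=torch.float16)  # illustrative shard width

print(torch.sum(exp_terms, dim=-1))                       # inf in fp16
print(torch.sum(exp_terms, dim=-1, dtype=torch.float32))  # 125000.0, finite
```

Hence the patch accumulates `sum_exp_logits` in float32 and casts the cached softmax back with `.to(dtype)`, so the gradient returned in backward stays in the model's own precision.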
@@ -10,6 +10,7 @@ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_m
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
+    CausalLMOutputWithPast,
     QuestionAnsweringModelOutput,
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
@@ -27,6 +28,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig

+from ..layer import cross_entropy_1d
+
 logger = logging.get_logger(__name__)


@@ -354,7 +357,7 @@ class BloomPipelineForwards:
         past_key_values = None
         if stage_manager.is_last_stage():
             hidden_states = transformer_outputs[0]
-            lm_logits = self.lm_head(hidden_states)
+            lm_logits = self.lm_head(hidden_states).contiguous()

             loss = None
             if labels is not None:
@@ -365,10 +368,21 @@ class BloomPipelineForwards:
                 shift_labels = labels[..., 1:].contiguous()
                 batch_size, seq_length, vocab_size = shift_logits.shape
                 # Flatten the tokens
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(
-                    shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
-                )
+                if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+                    new_vocab_size = lm_logits.shape[-1]
+                    shift_logits = shift_logits.view(-1, new_vocab_size)
+                    shift_labels = shift_labels.view(-1)
+                    loss = cross_entropy_1d(
+                        shift_logits,
+                        shift_labels,
+                        process_group=shard_config.tensor_parallel_process_group,
+                        vocab_size=self.lm_head.out_features,
+                        dtype=self.transformer.dtype,
+                    )
+                else:
+                    loss_fct = CrossEntropyLoss()
+                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                    loss = loss_fct(shift_logits, shift_labels.view(-1))

             if not return_dict:
                 output = (lm_logits,) + transformer_outputs[1:]
@@ -1065,3 +1079,79 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
         )

     return forward
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    from transformers import BloomForCausalLM
+
+    def forward(
+        self: BloomForCausalLM,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids=input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        past_key_values = None
+        hidden_states = transformer_outputs[0]
+        lm_logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(lm_logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            new_vocab_size = lm_logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            shift_labels = shift_labels.view(-1)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+                dtype=self.transformer.dtype,
+            )
+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+    return forward
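Editor's note on the shape bookkeeping above: the parallel branch reshapes exactly like the stock HuggingFace loss, except that the last dimension is the local vocabulary shard and the reduction is delegated to `cross_entropy_1d`. A single-process sketch of the same shift-and-flatten pattern with made-up sizes, using the plain `CrossEntropyLoss` that the else-branch keeps:

```python
import torch
from torch.nn import CrossEntropyLoss

# Illustrative sizes only; in the patched forward the last dimension would be
# the *local* vocab shard and cross_entropy_1d would reduce across the
# tensor-parallel group instead of CrossEntropyLoss.
batch_size, seq_len, vocab_size = 2, 6, 11
lm_logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

# Shift so that tokens < n predict n, then flatten to (N, vocab) and (N,).
shift_logits = lm_logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)

loss = CrossEntropyLoss(ignore_index=-100)(shift_logits, shift_labels)
print(shift_logits.shape, shift_labels.shape, loss.item())
```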
@@ -14,6 +14,7 @@ from transformers.modeling_attn_mask_utils import (
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
+    CausalLMOutputWithPast,
     QuestionAnsweringModelOutput,
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
@@ -31,6 +32,8 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig

+from ..layer import cross_entropy_1d
+

 def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
     def build_falcon_alibi_tensor(
@@ -437,14 +440,28 @@ class FalconPipelineForwards:
             loss = None
             if labels is not None:
                 # Shift so that tokens < n predict n
+                labels = labels.to(lm_logits.device)
                 shift_logits = lm_logits[..., :-1, :].contiguous()
                 shift_labels = labels[..., 1:].contiguous()
                 batch_size, seq_length, vocab_size = shift_logits.shape
                 # Flatten the tokens
                 loss_fct = CrossEntropyLoss()
-                loss = loss_fct(
-                    shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
-                )
+                if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+                    new_vocab_size = shift_logits.shape[-1]
+                    shift_logits = shift_logits.view(-1, new_vocab_size)
+                    shift_labels = shift_labels.view(-1)
+                    loss = cross_entropy_1d(
+                        shift_logits,
+                        shift_labels,
+                        process_group=shard_config.tensor_parallel_process_group,
+                        vocab_size=self.lm_head.out_features,
+                        dtype=self.transformer.dtype,
+                    )
+                else:
+                    loss = loss_fct(
+                        shift_logits.view(batch_size * seq_length, vocab_size),
+                        shift_labels.view(batch_size * seq_length),
+                    )

             if not return_dict:
                 output = (lm_logits,) + transformer_outputs[1:]
@@ -747,3 +764,79 @@ class FalconPipelineForwards:
         else:
             hidden_states = outputs.get("hidden_states")
             return {"hidden_states": hidden_states}
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    from transformers import FalconForCausalLM
+
+    def forward(
+        self: FalconForCausalLM,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        past_key_values = None
+        hidden_states = transformer_outputs[0]
+        lm_logits = self.lm_head(hidden_states)
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            labels = labels.to(lm_logits.device)
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            batch_size, seq_length, vocab_size = shift_logits.shape
+            # Flatten the tokens
+            new_vocab_size = shift_logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            shift_labels = shift_labels.view(-1)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+                dtype=self.transformer.dtype,
+            )

+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+    return forward
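Editor's note on why `parallel_output` is worth the extra branch in both the BLOOM and Falcon paths: with `gather_output=True` the full `[batch, seq, vocab]` logits are materialised on every rank before the loss, whereas the distributed loss only all-reduces a handful of per-token scalars (the shard's predicted logit and the softmax normalizer). A back-of-the-envelope comparison with illustrative sizes:

```python
# Rough, illustrative comparison of what crosses the tensor-parallel group per
# loss evaluation. Numbers are made up; the exact count of all-reduces follows
# DistCrossEntropy, which reduces a few [batch, seq]-shaped tensors.
batch, seq, vocab = 4, 2048, 250_880
fp16_bytes, fp32_bytes = 2, 4

gather_full_logits = batch * seq * vocab * fp16_bytes  # gathering lm_logits
dist_ce_allreduce = 3 * batch * seq * fp32_bytes       # ~3 per-token scalar reductions

print(f"gathered logits:     {gather_full_logits / 1e9:.1f} GB")
print(f"dist cross entropy:  {dist_ce_allreduce / 1e6:.2f} MB")
```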
@@ -389,6 +389,7 @@ class GPT2PipelineForwards:
                     shift_labels,
                     process_group=shard_config.tensor_parallel_process_group,
                     vocab_size=self.lm_head.out_features,
+                    dtype=self.transformer.dtype,
                 )
             else:
                 loss = loss_fct(shift_logits, shift_labels)
@@ -1294,6 +1295,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
                 shift_labels,
                 process_group=shard_config.tensor_parallel_process_group,
                 vocab_size=self.lm_head.out_features,
+                dtype=self.transformer.dtype,
             )

         if not return_dict:
@@ -332,6 +332,7 @@ class LlamaPipelineForwards:
                     shift_labels,
                     process_group=shard_config.tensor_parallel_process_group,
                     vocab_size=self.lm_head.out_features,
+                    dtype=self.model.dtype,
                 )
             else:
                 shift_logits = shift_logits.view(-1, self.config.vocab_size)
@@ -768,6 +769,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
                 shift_labels,
                 process_group=shard_config.tensor_parallel_process_group,
                 vocab_size=self.lm_head.out_features,
+                dtype=self.model.dtype,
             )

         if not return_dict:
@@ -281,6 +281,7 @@ class MistralForwards:
                     shift_labels,
                     process_group=shard_config.tensor_parallel_process_group,
                     vocab_size=self.lm_head.out_features,
+                    dtype=self.model.dtype,
                 )
             else:
                 shift_logits = shift_logits.view(-1, self.config.vocab_size)
@@ -701,6 +702,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
                 shift_labels,
                 process_group=shard_config.tensor_parallel_process_group,
                 vocab_size=self.lm_head.out_features,
+                dtype=self.model.dtype,
             )

         if not return_dict:
@@ -348,6 +348,7 @@ class OPTPipelineForwards:
                     shift_labels,
                     process_group=shard_config.tensor_parallel_process_group,
                     vocab_size=self.lm_head.out_features,
+                    dtype=self.model.decoder.dtype,
                 )
             else:
                 loss_fct = CrossEntropyLoss()
@@ -988,6 +989,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
                 shift_labels,
                 process_group=shard_config.tensor_parallel_process_group,
                 vocab_size=self.lm_head.out_features,
+                dtype=self.model.decoder.dtype,
             )

         if not return_dict:
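Editor's note: these one-line hunks all thread the owning module's parameter dtype (`self.transformer.dtype`, `self.model.dtype`, `self.model.decoder.dtype`) into `cross_entropy_1d`, so the softmax cached for backward and the returned `grad_logits` stay in the same precision as the fp16/bf16 logits. A small sketch of what that dtype resolves to, using a hypothetical toy module rather than the HF models themselves:

```python
import torch
import torch.nn as nn

# HF models expose their parameter dtype as `model.dtype`; for a toy module we
# can read the same information straight off the parameters.
toy_lm_head = nn.Linear(16, 32).to(torch.bfloat16)
param_dtype = next(toy_lm_head.parameters()).dtype
print(param_dtype)  # torch.bfloat16 -> the dtype cross_entropy_1d would receive
```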
@@ -16,6 +16,7 @@ from ..modeling.bloom import (
     get_jit_fused_bloom_attention_forward,
     get_jit_fused_bloom_gelu_forward,
     get_jit_fused_bloom_mlp_forward,
+    get_lm_forward_with_dist_cross_entropy,
 )
 from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -287,12 +288,18 @@ class BloomForCausalLMPolicy(BloomPolicy):
                     suffix="lm_head",
                     target_module=col_nn.VocabParallelLMHead1D,
                     kwargs=dict(
-                        gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
+                        gather_output=not self.shard_config.parallel_output,
+                        make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
                     ),
                 ),
                 policy=policy,
                 target_key=BloomForCausalLM,
             )
+            if self.shard_config.parallel_output:
+                method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
+                self.append_or_create_method_replacement(
+                    description=method_replacement, policy=policy, target_key=BloomForCausalLM
+                )
         else:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
@@ -7,7 +7,12 @@ from torch.nn import Module

 import colossalai.shardformer.layer as col_nn

-from ..modeling.falcon import FalconPipelineForwards, build_falcon_alibi_tensor_fn, get_tp_falcon_decoder_layer_forward
+from ..modeling.falcon import (
+    FalconPipelineForwards,
+    build_falcon_alibi_tensor_fn,
+    get_lm_forward_with_dist_cross_entropy,
+    get_tp_falcon_decoder_layer_forward,
+)
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ["FalconPolicy"]
@@ -233,12 +238,19 @@ class FalconForCausalLMPolicy(FalconPolicy):
                     suffix="lm_head",
                     target_module=col_nn.VocabParallelLMHead1D,
                     kwargs=dict(
-                        gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
+                        gather_output=not self.shard_config.parallel_output,
+                        make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
                     ),
                 ),
                 policy=policy,
                 target_key=FalconForCausalLM,
             )
+            if self.shard_config.parallel_output:
+                method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
+                self.append_or_create_method_replacement(
+                    description=method_replacement, policy=policy, target_key=FalconForCausalLM
+                )
+
         else:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
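Editor's note: both policies wire the feature the same way; the `lm_head` replacement keeps its output sharded (`gather_output=not parallel_output`) and the causal-LM forward is swapped for the dist-cross-entropy variant. A hedged sketch of how this is typically activated through ShardFormer; it assumes an already-initialized distributed environment and that `model` is a BloomForCausalLM or FalconForCausalLM instance:

```python
import torch.distributed as dist
from colossalai.shardformer import ShardConfig, ShardFormer

# Sketch only: assumes torch.distributed is initialized (e.g. via ColossalAI's
# launcher) and `model` is a BloomForCausalLM / FalconForCausalLM instance.
tp_group = dist.group.WORLD  # or a dedicated tensor-parallel sub-group

shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,
    enable_tensor_parallelism=True,
    parallel_output=True,  # keep lm_head sharded and use cross_entropy_1d
)
shard_former = ShardFormer(shard_config=shard_config)
# sharded_model, shared_params = shard_former.optimize(model)
```

With `parallel_output=False` the policies fall back to the previous behaviour: `gather_output=True` and the stock CrossEntropyLoss on the gathered logits.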