mirror of https://github.com/hpcaitech/ColossalAI
[Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* use one cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 66abf1c6e8
commit fbf33ecd01
colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1205,6 +1205,7 @@ class HybridParallelPlugin(PipelinePluginBase):
     and self.enable_sequence_parallelism
     and self.sequence_parallelism_mode == "all_to_all"
 )
 # sync gradients across DP * SP ranks
 if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
     dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
 else:
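The hunk above routes gradient synchronization over the combined DP x SP axes when Ulysses (all_to_all) sequence parallelism is active. A minimal sketch of the underlying idea in plain torch.distributed; the helper name and the explicit all-reduce loop are illustrative assumptions, not the plugin's internals:

```python
import torch
import torch.distributed as dist


def sync_grads_over_dp_sp(model: torch.nn.Module, dp_sp_group: dist.ProcessGroup) -> None:
    """Average gradients over every DP * SP rank.

    With all_to_all sequence parallelism each SP rank holds a full replica of the
    (non tensor-parallel) weights but computed its loss on a different sequence
    shard, so gradients must be reduced across data-parallel and sequence-parallel
    ranks together rather than across the DP group alone.
    """
    world = dist.get_world_size(group=dp_sp_group)
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=dp_sp_group)
            p.grad.div_(world)
```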
colossalai/shardformer/layer/__init__.py
@@ -3,7 +3,7 @@ from .attn import AttnMaskType, ColoAttention
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
-from .loss import cross_entropy_1d
+from .loss import cross_entropy_1d, dist_cross_entropy
 from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
 from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row

@@ -18,6 +18,7 @@ __all__ = [
     "DropoutForParallelInput",
     "DropoutForReplicatedInput",
     "cross_entropy_1d",
+    "dist_cross_entropy",
     "BaseLayerNorm",
     "LayerNorm",
     "RMSNorm",
colossalai/shardformer/layer/loss.py
@@ -2,8 +2,11 @@ import torch
 import torch.distributed as dist
 from torch.autograd import Function
 from torch.distributed import ProcessGroup
+from torch.nn import CrossEntropyLoss

-__all__ = ["DistCrossEntropy", "cross_entropy_1d"]
+from colossalai.shardformer.shard import ShardConfig
+
+__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]


 class DistCrossEntropy(Function):
@@ -132,3 +135,43 @@ def cross_entropy_1d(
     dtype: torch.dtype = None,
 ) -> torch.Tensor:
     return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
+
+
+def dist_cross_entropy(
+    labels: torch.Tensor,
+    logits: torch.Tensor,
+    shard_config: ShardConfig,
+    out_features: int,
+    vocab_size: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    Helper to compute cross entropy loss for most shardformer models,
+    compatible with PP, TP and SP.
+    """
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_labels = shift_labels.view(-1)
+        shift_labels = shift_labels.to(shift_logits.device)
+        if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+            # Cross entropy with all-reduce for TP
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=out_features,
+                dtype=dtype,
+            )
+        else:
+            # NOTE if use TP and not parallel_output, the output is gathered.
+            # see VocabParallelLMHead1D
+            shift_logits = shift_logits.view(-1, vocab_size)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        return loss
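For orientation, a self-contained sketch of what dist_cross_entropy reduces to when tensor parallelism is off (the parallel_output branch that calls cross_entropy_1d with a TP all-reduce is omitted); this is an illustrative reference, not library code:

```python
import torch
from torch.nn import CrossEntropyLoss


def reference_causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Shift so that tokens < n predict token n, then flatten and apply plain CE,
    # mirroring the non-TP path of dist_cross_entropy above.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    loss_fct = CrossEntropyLoss()
    return loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels)


# toy check: batch of 2 sequences, length 8, vocabulary of 16 tokens
logits = torch.randn(2, 8, 16)
labels = torch.randint(0, 16, (2, 8))
print(reference_causal_lm_loss(logits, labels))
```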
colossalai/shardformer/modeling/bloom.py
@@ -28,7 +28,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import cross_entropy_1d
+from ..layer import dist_cross_entropy

 logger = logging.get_logger(__name__)

@@ -359,30 +359,14 @@ class BloomPipelineForwards:
 hidden_states = transformer_outputs[0]
 lm_logits = self.lm_head(hidden_states).contiguous()

-loss = None
-if labels is not None:
-    # move labels to correct device to enable model parallelism
-    labels = labels.to(lm_logits.device)
-    # Shift so that tokens < n predict n
-    shift_logits = lm_logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    batch_size, seq_length, vocab_size = shift_logits.shape
-    # Flatten the tokens
-    if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-        new_vocab_size = lm_logits.shape[-1]
-        shift_logits = shift_logits.view(-1, new_vocab_size)
-        shift_labels = shift_labels.view(-1)
-        loss = cross_entropy_1d(
-            shift_logits,
-            shift_labels,
-            process_group=shard_config.tensor_parallel_process_group,
-            vocab_size=self.lm_head.out_features,
-            dtype=self.transformer.dtype,
-        )
-    else:
-        loss_fct = CrossEntropyLoss()
-        shift_logits = shift_logits.view(-1, self.config.vocab_size)
-        loss = loss_fct(shift_logits, shift_labels.view(-1))
+loss = dist_cross_entropy(
+    labels,
+    lm_logits,
+    shard_config,
+    self.lm_head.out_features,
+    self.config.vocab_size,
+    self.transformer.dtype,
+)

 if not return_dict:
     output = (lm_logits,) + transformer_outputs[1:]

@@ -1040,24 +1024,10 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
 hidden_states = transformer_outputs[0]
 lm_logits = self.lm_head(hidden_states)

-loss = None
-if labels is not None:
-    # move labels to correct device to enable model parallelism
-    labels = labels.to(lm_logits.device)
-    # Shift so that tokens < n predict n
-    shift_logits = lm_logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    # Flatten the tokens
-    new_vocab_size = lm_logits.shape[-1]
-    shift_logits = shift_logits.view(-1, new_vocab_size)
-    shift_labels = shift_labels.view(-1)
-    loss = cross_entropy_1d(
-        shift_logits,
-        shift_labels,
-        process_group=shard_config.tensor_parallel_process_group,
-        vocab_size=self.lm_head.out_features,
-        dtype=self.transformer.dtype,
-    )
+loss = dist_cross_entropy(
+    labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
+)

 if not return_dict:
     output = (lm_logits,) + transformer_outputs[1:]
     return ((loss,) + output) if loss is not None else output
colossalai/shardformer/modeling/command.py
@@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.cohere.modeling_cohere import (

@@ -25,7 +24,7 @@ from colossalai.shardformer.layer._operation import (
 )
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import ColoAttention, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy


 class CommandPipelineForwards:

@@ -300,29 +299,9 @@ class CommandPipelineForwards:
 logits = self.lm_head(hidden_states)
 logits = logits * self.logit_scale
 logits = logits.float()
-loss = None
-if labels is not None:
-    # Shift so that tokens < n predict n
-    shift_logits = logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    # Flatten the tokens
-    loss_fct = CrossEntropyLoss()
-    shift_labels = shift_labels.view(-1)
-    # Enable model parallelism
-    shift_labels = shift_labels.to(shift_logits.device)
-    if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-        new_vocab_size = logits.shape[-1]
-        shift_logits = shift_logits.view(-1, new_vocab_size)
-        loss = cross_entropy_1d(
-            shift_logits,
-            shift_labels,
-            process_group=shard_config.tensor_parallel_process_group,
-            vocab_size=self.lm_head.out_features,
-            dtype=self.model.dtype,
-        )
-    else:
-        shift_logits = shift_logits.view(-1, self.config.vocab_size)
-        loss = loss_fct(shift_logits, shift_labels)
+loss = dist_cross_entropy(
+    labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+)

 if not return_dict:
     output = (logits,) + outputs[1:]

@@ -658,24 +637,14 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
 logits = self.lm_head(hidden_states)
 logits = logits * self.logit_scale
 logits = logits.float()

-loss = None
-if labels is not None:
-    # Shift so that tokens < n predict n
-    shift_logits = logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    shift_labels = shift_labels.view(-1)
-    # Enable model parallelism
-    shift_labels = shift_labels.to(shift_logits.device)
-    new_vocab_size = logits.shape[-1]
-    shift_logits = shift_logits.view(-1, new_vocab_size)
-    loss = cross_entropy_1d(
-        shift_logits,
-        shift_labels,
-        process_group=shard_config.tensor_parallel_process_group,
-        vocab_size=self.lm_head.out_features,
-        dtype=self.model.dtype,
-    )
+loss = dist_cross_entropy(
+    labels,
+    logits,
+    shard_config,
+    self.lm_head.out_features,
+    self.config.vocab_size,
+    self.model.dtype,
+)

 if not return_dict:
     output = (logits,) + outputs[1:]
colossalai/shardformer/modeling/gpt2.py
@@ -25,7 +25,7 @@ from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import cross_entropy_1d
+from ..layer import dist_cross_entropy

 logger = logging.get_logger(__name__)

@@ -372,27 +372,9 @@ class GPT2PipelineForwards:

 hidden_states = outputs[0]
 lm_logits = self.lm_head(hidden_states)
-loss = None
-if labels is not None:
-    # move labels to correct device to enable model parallelism
-    labels = labels.to(lm_logits.device)
-    # Shift so that tokens < n predict n
-    shift_logits = lm_logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    # Flatten the tokens
-    loss_fct = CrossEntropyLoss()
-    shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-    shift_labels = shift_labels.view(-1)
-    if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-        loss = cross_entropy_1d(
-            shift_logits,
-            shift_labels,
-            process_group=shard_config.tensor_parallel_process_group,
-            vocab_size=self.lm_head.out_features,
-            dtype=self.transformer.dtype,
-        )
-    else:
-        loss = loss_fct(shift_logits, shift_labels)
+loss = dist_cross_entropy(
+    labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
+)

 if not return_dict:
     output = (lm_logits,) + outputs[1:]

@@ -1282,24 +1264,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
 hidden_states = transformer_outputs[0]

 lm_logits = self.lm_head(hidden_states)

-loss = None
-if labels is not None:
-    # move labels to correct device to enable model parallelism
-    labels = labels.to(lm_logits.device)
-    # Shift so that tokens < n predict n
-    shift_logits = lm_logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    # Flatten the tokens
-    shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-    shift_labels = shift_labels.view(-1)
-    loss = cross_entropy_1d(
-        shift_logits,
-        shift_labels,
-        process_group=shard_config.tensor_parallel_process_group,
-        vocab_size=self.lm_head.out_features,
-        dtype=self.transformer.dtype,
-    )
+loss = dist_cross_entropy(
+    labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
+)

 if not return_dict:
     output = (lm_logits,) + transformer_outputs[1:]
colossalai/shardformer/modeling/llama.py
@@ -31,7 +31,7 @@ from colossalai.shardformer.layer._operation import (
 )
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import ColoAttention, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy


 class LlamaPipelineForwards:
@@ -86,13 +86,20 @@ class LlamaPipelineForwards:
     device = input_ids.device if input_ids is not None else inputs_embeds.device
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)

     hidden_states = inputs_embeds
 else:
     input_shape = hidden_states.shape[:-1]
     batch_size, seq_length = input_shape
     device = hidden_states.device

+# Support SP + PP
+sp_mode = shard_config.sequence_parallelism_mode
+sp_group = shard_config.sequence_parallel_process_group
+sp_size = shard_config.sequence_parallel_size
+if sp_mode == "all_to_all" and not stage_manager.is_first_stage():
+    # For correct positions ids. The states will be gather along the seq dim in the attention layer later.
+    seq_length *= sp_size

 past_seen_tokens = 0
 if use_cache:  # kept for BC (cache positions)
     if not isinstance(past_key_values, StaticCache):

@@ -101,7 +108,7 @@ class LlamaPipelineForwards:
 if cache_position is None:
     if isinstance(past_key_values, StaticCache):
         raise ValueError("cache_position is a required argument when using StaticCache.")
-    cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)
+    cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)

 seq_length_with_past = seq_length + past_seen_tokens
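The two hunks above are the cross-PP-stage position id fix: with all_to_all sequence parallelism, a non-first pipeline stage receives hidden states that are already sharded along the sequence dimension, so the locally observed length has to be rescaled by sp_size before cache_position is built. A toy illustration; the numbers and variable names are made up for the example:

```python
import torch

sp_size = 2            # Ulysses sequence-parallel degree
full_seq_len = 8       # sequence length seen on the first stage
past_seen_tokens = 0

# A later stage only holds full_seq_len // sp_size tokens of hidden states.
local_seq_len = full_seq_len // sp_size

# Before the fix: positions derived from the local shard cover only part of the sequence.
cache_position_bug = torch.arange(past_seen_tokens, past_seen_tokens + local_seq_len)
# After the fix: rescaling by sp_size makes the position ids span the full sequence.
cache_position_fix = torch.arange(past_seen_tokens, past_seen_tokens + local_seq_len * sp_size)

print(cache_position_bug.shape, cache_position_fix.shape)  # torch.Size([4]) torch.Size([8])
```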
@@ -118,7 +125,6 @@ class LlamaPipelineForwards:
 if position_ids is None:
     position_ids = cache_position.unsqueeze(0)

 # embed positions, for the first stage, hidden_states is the input embeddings,
 # for the other stages, hidden_states is the output of the previous stage
 if shard_config.enable_flash_attention:

@@ -134,6 +140,13 @@ class LlamaPipelineForwards:
 else:
     attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)

+# Support SP + PP
+if stage_manager.is_first_stage():
+    if sp_mode in ["ring", "split_gather"]:
+        hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
+    elif sp_mode == "all_to_all":
+        hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)

 if self.gradient_checkpointing and self.training and use_cache:
     if use_cache:
         logger.warning_once(

@@ -196,6 +209,10 @@ class LlamaPipelineForwards:

 if stage_manager.is_last_stage():
     hidden_states = self.norm(hidden_states)
+    if sp_mode == "ring" or sp_mode == "split_gather":
+        hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+    elif sp_mode == "all_to_all":
+        hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)

     # add hidden states from the last decoder layer
     if output_hidden_states:
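The split/gather hunks above shard the hidden states along the sequence dimension on the first stage and reassemble them on the last stage. The real split_forward_gather_backward / gather_forward_split_backward ops also wire up communication and autograd across ranks; the sketch below only shows the single-process shape bookkeeping they imply:

```python
import torch

sp_size = 2
hidden_states = torch.randn(2, 8, 16)    # (batch, seq, hidden)

# First stage: split along the sequence dim (dim=1), one shard per SP rank.
shards = torch.chunk(hidden_states, sp_size, dim=1)
print([tuple(s.shape) for s in shards])  # [(2, 4, 16), (2, 4, 16)]

# Last stage: gather the shards back so the final norm and LM head see the full sequence.
gathered = torch.cat(shards, dim=1)
assert torch.equal(gathered, hidden_states)
```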
@@ -304,29 +321,9 @@ class LlamaPipelineForwards:
 if stage_manager.is_last_stage():
     hidden_states = outputs[0]
     logits = self.lm_head(hidden_states)
-    loss = None
-    if labels is not None:
-        # Shift so that tokens < n predict n
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        # Flatten the tokens
-        loss_fct = CrossEntropyLoss()
-        shift_labels = shift_labels.view(-1)
-        # Enable model parallelism
-        shift_labels = shift_labels.to(shift_logits.device)
-        if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-            new_vocab_size = logits.shape[-1]
-            shift_logits = shift_logits.view(-1, new_vocab_size)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.model.dtype,
-            )
-        else:
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            loss = loss_fct(shift_logits, shift_labels)
+    loss = dist_cross_entropy(
+        labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+    )

     if not return_dict:
         output = (logits,) + outputs[1:]

@@ -529,7 +526,6 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
 )

-kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

 cos, sin = self.rotary_emb(value_states, position_ids)
 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

@@ -804,24 +800,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
 logits = self.lm_head(hidden_states)
 logits = logits.float()

-loss = None
-if labels is not None:
-    # Shift so that tokens < n predict n
-    shift_logits = logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    shift_labels = shift_labels.view(-1)
-    # Enable model parallelism
-    shift_labels = shift_labels.to(shift_logits.device)
-    new_vocab_size = logits.shape[-1]
-    shift_logits = shift_logits.view(-1, new_vocab_size)
-    loss = cross_entropy_1d(
-        shift_logits,
-        shift_labels,
-        process_group=shard_config.tensor_parallel_process_group,
-        vocab_size=self.lm_head.out_features,
-        dtype=self.model.dtype,
-    )
+loss = dist_cross_entropy(
+    labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+)
 if not return_dict:
     output = (logits,) + outputs[1:]
     return (loss,) + output if loss is not None else output
colossalai/shardformer/modeling/mistral.py
@@ -19,7 +19,7 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import ColoAttention, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy

 logger = logging.get_logger(__name__)

@@ -275,29 +275,9 @@ class MistralForwards:
 logits = self.lm_head(hidden_states)
 logits = logits.float()

-loss = None
-if labels is not None:
-    # Shift so that tokens < n predict n
-    shift_logits = logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    # Flatten the tokens
-    loss_fct = CrossEntropyLoss()
-    shift_labels = shift_labels.view(-1)
-    # Enable model parallelism
-    shift_labels = shift_labels.to(shift_logits.device)
-    if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-        new_vocab_size = logits.shape[-1]
-        shift_logits = shift_logits.view(-1, new_vocab_size)
-        loss = cross_entropy_1d(
-            shift_logits,
-            shift_labels,
-            process_group=shard_config.tensor_parallel_process_group,
-            vocab_size=self.lm_head.out_features,
-            dtype=self.model.dtype,
-        )
-    else:
-        shift_logits = shift_logits.view(-1, self.config.vocab_size)
-        loss = loss_fct(shift_logits, shift_labels)
+loss = dist_cross_entropy(
+    labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+)

 if not return_dict:
     output = (logits,) + outputs[1:]

@@ -708,23 +688,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
 logits = self.lm_head(hidden_states)
 logits = logits.float()

-loss = None
-if labels is not None:
-    # Shift so that tokens < n predict n
-    shift_logits = logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    shift_labels = shift_labels.view(-1)
-    # Enable model parallelism
-    shift_labels = shift_labels.to(shift_logits.device)
-    new_vocab_size = logits.shape[-1]
-    shift_logits = shift_logits.view(-1, new_vocab_size)
-    loss = cross_entropy_1d(
-        shift_logits,
-        shift_labels,
-        process_group=shard_config.tensor_parallel_process_group,
-        vocab_size=self.lm_head.out_features,
-        dtype=self.model.dtype,
-    )
+loss = dist_cross_entropy(
+    labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+)

 if not return_dict:
     output = (logits,) + outputs[1:]
colossalai/shardformer/modeling/opt.py
@@ -22,7 +22,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import cross_entropy_1d
+from ..layer import dist_cross_entropy

 logger = logging.get_logger(__name__)

@@ -330,30 +330,14 @@ class OPTPipelineForwards:
 )
 if stage_manager.is_last_stage():
     logits = self.lm_head(outputs[0]).contiguous()
-    loss = None
-    if labels is not None:
-        # move labels to correct device to enable model parallelism
-        labels = labels.to(logits.device)
-        # Shift so that tokens < n predict n
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        # Flatten the tokens
-
-        if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-            new_vocab_size = logits.shape[-1]
-            shift_logits = shift_logits.view(-1, new_vocab_size)
-            shift_labels = shift_labels.view(-1)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.model.decoder.dtype,
-            )
-        else:
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+    loss = dist_cross_entropy(
+        labels,
+        logits,
+        shard_config,
+        self.lm_head.out_features,
+        self.config.vocab_size,
+        self.model.decoder.dtype,
+    )

     if not return_dict:
         output = (logits,) + outputs[1:]

@@ -971,26 +955,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
 )

 logits = self.lm_head(outputs[0]).contiguous()

-loss = None
-if labels is not None:
-    # move labels to correct device to enable model parallelism
-    labels = labels.to(logits.device)
-    # Shift so that tokens < n predict n
-    shift_logits = logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    shift_labels = shift_labels.view(-1)
-    # Enable model parallelism
-    shift_labels = shift_labels.to(shift_logits.device)
-    new_vocab_size = logits.shape[-1]
-    shift_logits = shift_logits.view(-1, new_vocab_size)
-    loss = cross_entropy_1d(
-        shift_logits,
-        shift_labels,
-        process_group=shard_config.tensor_parallel_process_group,
-        vocab_size=self.lm_head.out_features,
-        dtype=self.model.decoder.dtype,
-    )
+loss = dist_cross_entropy(
+    labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.decoder.dtype
+)

 if not return_dict:
     output = (logits,) + outputs[1:]
colossalai/shardformer/modeling/qwen2.py
@@ -32,7 +32,7 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import ColoAttention, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy


 class Qwen2PipelineForwards:

@@ -317,25 +317,9 @@ class Qwen2PipelineForwards:
 if stage_manager.is_last_stage():
     hidden_states = outputs[0]
     logits = self.lm_head(hidden_states)
-    loss = None
-    if labels is not None:
-        # Shift so that tokens < n predict n
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        # Flatten the tokens
-        loss_fct = CrossEntropyLoss()
-        shift_labels = shift_labels.view(-1)
-        # Enable model parallelism
-        shift_labels = shift_labels.to(shift_logits.device)
-        if shard_config.enable_tensor_parallelism:
-            new_vocab_size = logits.shape[-1]
-            shift_logits = shift_logits.view(-1, new_vocab_size)
-            loss = cross_entropy_1d(
-                shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
-            )
-        else:
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            loss = loss_fct(shift_logits, shift_labels)
+    loss = dist_cross_entropy(
+        labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
+    )

     if not return_dict:
         output = (logits,) + outputs[1:]

@@ -737,26 +721,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
 hidden_states = outputs[0]
 logits = self.lm_head(hidden_states)
 logits = logits.float()

-loss = None
-if labels is not None:
-    # Shift so that tokens < n predict n
-    shift_logits = logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    # Flatten the tokens
-    loss_fct = CrossEntropyLoss()
-    shift_labels = shift_labels.view(-1)
-    # Enable model parallelism
-    shift_labels = shift_labels.to(shift_logits.device)
-    if shard_config.enable_tensor_parallelism:
-        new_vocab_size = logits.shape[-1]
-        shift_logits = shift_logits.view(-1, new_vocab_size)
-        loss = cross_entropy_1d(
-            shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
-        )
-    else:
-        shift_logits = shift_logits.view(-1, self.config.vocab_size)
-        loss = loss_fct(shift_logits, shift_labels)
+loss = dist_cross_entropy(
+    labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
+)

 if not return_dict:
     output = (logits,) + outputs[1:]
colossalai/shardformer/policies/llama.py
@@ -1,4 +1,3 @@
-import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union

@@ -66,13 +65,6 @@ class LlamaPolicy(Policy):
 else:
     norm_cls = RMSNorm

-if self.pipeline_stage_manager is not None:
-    self.shard_config.enable_sequence_parallelism = False
-    self.shard_config.enable_sequence_overlap = False
-    self.shard_config.sequence_parallelism_mode = None
-    warnings.warn(
-        f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
-    )
 sp_mode = self.shard_config.sequence_parallelism_mode or None
 sp_size = self.shard_config.sequence_parallel_size or None
 sp_group = self.shard_config.sequence_parallel_process_group or None
tests/test_shardformer/test_model/test_shard_llama.py
@@ -59,10 +59,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 if (
     booster.plugin.zero_stage in [1, 2]
     and booster.plugin.shard_config.enable_sequence_parallelism
+    and booster.plugin.shard_config.pipeline_stage_manager is None
     and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
 ):
+    master2working = sharded_optimizer.get_master_to_working_map()
     for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
-        working_p = sharded_optimizer.master_to_working_param[id(p2)]
+        working_p = master2working[id(p2)]
         grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
         grad_index = (
             0
@@ -146,6 +148,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
+        {  # Ulysess + Flash attention
+            "tp_size": 1,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 0,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {  # Test ring + Flash attention
             "tp_size": 2,
             "pp_size": 1,

@@ -159,19 +174,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        {  # Ulysess + Flash attention
-            "tp_size": 1,
-            "pp_size": 2,
-            "sp_size": 2,
-            "num_microbatches": 2,
-            "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "all_to_all",
-            "enable_flash_attention": True,
-            "use_lazy_init": True,
-            "zero_stage": 1,
-            "precision": "fp16",
-            "initial_scale": 1,
-        },
         {
             "tp_size": 1,
             "pp_size": 1,
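The relocated test entry exercises PP + Ulysses SP (tp_size=1, pp_size=2, sp_size=2, all_to_all). A hedged sketch of how such a test_config plausibly maps onto HybridParallelPlugin arguments when run under a colossalai launcher; the exact wiring (and harness-only flags such as use_lazy_init) lives in the test utilities, so treat this as an assumption rather than the test's own code:

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumed mapping of the test_config keys above onto plugin keyword arguments.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    sp_size=2,
    num_microbatches=2,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
    enable_flash_attention=True,
    zero_stage=0,
    precision="fp16",
    initial_scale=1,
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion and dataloader would then be wrapped via booster.boost(...)
```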
@@ -245,7 +247,6 @@ def run_llama_test(test_config):
 except Exception as e:
     print(f"Failed config: {test_config}")
     raise e

 clear_layout_converter()
 Randomizer.reset_index()
 torch.cuda.empty_cache()