[Feature] Enable PP + SP for llama (#5868)

* fix cross-PP-stage position id length mismatch bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use one shared cross-entropy function for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Edenzzzz 2024-07-09 18:05:20 +08:00 committed by GitHub
parent 66abf1c6e8
commit fbf33ecd01
12 changed files with 148 additions and 323 deletions

View File

@@ -1205,6 +1205,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             and self.enable_sequence_parallelism
             and self.sequence_parallelism_mode == "all_to_all"
         )
+        # sync gradients across DP * SP ranks
         if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
             dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
         else:
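For context, a minimal sketch of the plugin configuration this change targets (pipeline parallelism combined with Ulysses-style sequence parallelism). The keyword values simply mirror the "all_to_all" test config added at the bottom of this diff and are illustrative assumptions, not part of the commit; exact plugin defaults may differ across versions.

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# PP (pp_size=2) + Ulysses SP (sp_size=2, all_to_all), no TP, fp16
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    sp_size=2,
    num_microbatches=2,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
    enable_flash_attention=True,
    zero_stage=0,
    precision="fp16",
    initial_scale=1,
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader are assumed to exist:
# model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)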

View File

@@ -3,7 +3,7 @@ from .attn import AttnMaskType, ColoAttention
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
-from .loss import cross_entropy_1d
+from .loss import cross_entropy_1d, dist_cross_entropy
 from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
 from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
@@ -18,6 +18,7 @@ __all__ = [
     "DropoutForParallelInput",
     "DropoutForReplicatedInput",
     "cross_entropy_1d",
+    "dist_cross_entropy",
     "BaseLayerNorm",
     "LayerNorm",
     "RMSNorm",

View File

@@ -2,8 +2,11 @@ import torch
 import torch.distributed as dist
 from torch.autograd import Function
 from torch.distributed import ProcessGroup
+from torch.nn import CrossEntropyLoss
 
-__all__ = ["DistCrossEntropy", "cross_entropy_1d"]
+from colossalai.shardformer.shard import ShardConfig
+
+__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
 
 
 class DistCrossEntropy(Function):
@@ -132,3 +135,43 @@ def cross_entropy_1d(
     dtype: torch.dtype = None,
 ) -> torch.Tensor:
     return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
+
+
+def dist_cross_entropy(
+    labels: torch.Tensor,
+    logits: torch.Tensor,
+    shard_config: ShardConfig,
+    out_features: int,
+    vocab_size: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    Helper to compute cross entropy loss for most shardformer models,
+    compatible with PP, TP and SP.
+    """
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_labels = shift_labels.view(-1)
+        shift_labels = shift_labels.to(shift_logits.device)
+        if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+            # Cross entropy with all-reduce for TP
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=out_features,
+                dtype=dtype,
+            )
+        else:
+            # NOTE: if TP is used but parallel_output is False, the logits are
+            # already gathered; see VocabParallelLMHead1D
            shift_logits = shift_logits.view(-1, vocab_size)
            loss = loss_fct(shift_logits, shift_labels)

        return loss
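For reference, a standalone sketch of what dist_cross_entropy reduces to on the non-TP path (parallel_output disabled): the standard next-token shift-and-flatten cross entropy. The shapes below are made up purely for illustration.

import torch
from torch.nn import CrossEntropyLoss

batch, seq, vocab = 2, 8, 32
logits = torch.randn(batch, seq, vocab)
labels = torch.randint(0, vocab, (batch, seq))

# predict token t+1 from token t, then flatten
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = CrossEntropyLoss()(shift_logits, shift_labels)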

View File

@@ -28,7 +28,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import cross_entropy_1d
+from ..layer import dist_cross_entropy
 
 logger = logging.get_logger(__name__)
 
@@ -359,30 +359,14 @@ class BloomPipelineForwards:
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states).contiguous()
-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            batch_size, seq_length, vocab_size = shift_logits.shape
-            # Flatten the tokens
-            if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-                new_vocab_size = lm_logits.shape[-1]
-                shift_logits = shift_logits.view(-1, new_vocab_size)
-                shift_labels = shift_labels.view(-1)
-                loss = cross_entropy_1d(
-                    shift_logits,
-                    shift_labels,
-                    process_group=shard_config.tensor_parallel_process_group,
-                    vocab_size=self.lm_head.out_features,
-                    dtype=self.transformer.dtype,
-                )
-            else:
-                loss_fct = CrossEntropyLoss()
-                shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                loss = loss_fct(shift_logits, shift_labels.view(-1))
+        loss = dist_cross_entropy(
+            labels,
+            lm_logits,
+            shard_config,
+            self.lm_head.out_features,
+            self.config.vocab_size,
+            self.transformer.dtype,
+        )
 
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
@@ -1040,24 +1024,10 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)
-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            new_vocab_size = lm_logits.shape[-1]
-            shift_logits = shift_logits.view(-1, new_vocab_size)
-            shift_labels = shift_labels.view(-1)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.transformer.dtype,
-            )
+        loss = dist_cross_entropy(
+            labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
+        )
 
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
         return ((loss,) + output) if loss is not None else output

View File

@@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.cohere.modeling_cohere import (
@@ -25,7 +24,7 @@ from colossalai.shardformer.layer._operation import (
 )
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import ColoAttention, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy
 
 
 class CommandPipelineForwards:
@@ -300,29 +299,9 @@ class CommandPipelineForwards:
         logits = self.lm_head(hidden_states)
         logits = logits * self.logit_scale
         logits = logits.float()
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-                new_vocab_size = logits.shape[-1]
-                shift_logits = shift_logits.view(-1, new_vocab_size)
-                loss = cross_entropy_1d(
-                    shift_logits,
-                    shift_labels,
-                    process_group=shard_config.tensor_parallel_process_group,
-                    vocab_size=self.lm_head.out_features,
-                    dtype=self.model.dtype,
-                )
-            else:
-                shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                loss = loss_fct(shift_logits, shift_labels)
+        loss = dist_cross_entropy(
+            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+        )
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -658,23 +637,13 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         logits = self.lm_head(hidden_states)
         logits = logits * self.logit_scale
         logits = logits.float()
-
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            new_vocab_size = logits.shape[-1]
-            shift_logits = shift_logits.view(-1, new_vocab_size)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.model.dtype,
-            )
+        loss = dist_cross_entropy(
+            labels,
+            logits,
+            shard_config,
+            self.lm_head.out_features,
+            self.config.vocab_size,
+            self.model.dtype,
+        )
 
         if not return_dict:

View File

@@ -25,7 +25,7 @@ from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import cross_entropy_1d
+from ..layer import dist_cross_entropy
 
 logger = logging.get_logger(__name__)
 
@@ -372,27 +372,9 @@ class GPT2PipelineForwards:
         hidden_states = outputs[0]
         lm_logits = self.lm_head(hidden_states)
-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-            shift_labels = shift_labels.view(-1)
-            if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-                loss = cross_entropy_1d(
-                    shift_logits,
-                    shift_labels,
-                    process_group=shard_config.tensor_parallel_process_group,
-                    vocab_size=self.lm_head.out_features,
-                    dtype=self.transformer.dtype,
-                )
-            else:
-                loss = loss_fct(shift_logits, shift_labels)
+        loss = dist_cross_entropy(
+            labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
+        )
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
@@ -1282,23 +1264,8 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)
-
-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-            shift_labels = shift_labels.view(-1)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.transformer.dtype,
-            )
+        loss = dist_cross_entropy(
+            labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
+        )
 
         if not return_dict:

View File

@@ -31,7 +31,7 @@ from colossalai.shardformer.layer._operation import (
 )
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import ColoAttention, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy
 
 
 class LlamaPipelineForwards:
@@ -86,13 +86,20 @@ class LlamaPipelineForwards:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             if inputs_embeds is None:
                 inputs_embeds = self.embed_tokens(input_ids)
             hidden_states = inputs_embeds
         else:
             input_shape = hidden_states.shape[:-1]
             batch_size, seq_length = input_shape
             device = hidden_states.device
 
+        # Support SP + PP
+        sp_mode = shard_config.sequence_parallelism_mode
+        sp_group = shard_config.sequence_parallel_process_group
+        sp_size = shard_config.sequence_parallel_size
+        if sp_mode == "all_to_all" and not stage_manager.is_first_stage():
+            # For correct position ids: the states will be gathered along the seq dim in the attention layer later.
+            seq_length *= sp_size
+
         past_seen_tokens = 0
         if use_cache:  # kept for BC (cache positions)
             if not isinstance(past_key_values, StaticCache):
@@ -101,7 +108,7 @@ class LlamaPipelineForwards:
         if cache_position is None:
             if isinstance(past_key_values, StaticCache):
                 raise ValueError("cache_position is a required argument when using StaticCache.")
-            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)
+            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
 
         seq_length_with_past = seq_length + past_seen_tokens
 
@@ -118,7 +125,6 @@ class LlamaPipelineForwards:
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
-
         # embed positions, for the first stage, hidden_states is the input embeddings,
         # for the other stages, hidden_states is the output of the previous stage
         if shard_config.enable_flash_attention:
@@ -134,6 +140,13 @@ class LlamaPipelineForwards:
         else:
             attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
 
+        # Support SP + PP
+        if stage_manager.is_first_stage():
+            if sp_mode in ["ring", "split_gather"]:
+                hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
+            elif sp_mode == "all_to_all":
+                hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
+
         if self.gradient_checkpointing and self.training and use_cache:
             if use_cache:
                 logger.warning_once(
@@ -196,6 +209,10 @@ class LlamaPipelineForwards:
         if stage_manager.is_last_stage():
             hidden_states = self.norm(hidden_states)
+            if sp_mode == "ring" or sp_mode == "split_gather":
+                hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+            elif sp_mode == "all_to_all":
+                hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
@@ -304,29 +321,9 @@ class LlamaPipelineForwards:
         if stage_manager.is_last_stage():
             hidden_states = outputs[0]
             logits = self.lm_head(hidden_states)
-            loss = None
-            if labels is not None:
-                # Shift so that tokens < n predict n
-                shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = labels[..., 1:].contiguous()
-                # Flatten the tokens
-                loss_fct = CrossEntropyLoss()
-                shift_labels = shift_labels.view(-1)
-                # Enable model parallelism
-                shift_labels = shift_labels.to(shift_logits.device)
-                if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-                    new_vocab_size = logits.shape[-1]
-                    shift_logits = shift_logits.view(-1, new_vocab_size)
-                    loss = cross_entropy_1d(
-                        shift_logits,
-                        shift_labels,
-                        process_group=shard_config.tensor_parallel_process_group,
-                        vocab_size=self.lm_head.out_features,
-                        dtype=self.model.dtype,
-                    )
-                else:
-                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                    loss = loss_fct(shift_logits, shift_labels)
+            loss = dist_cross_entropy(
+                labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+            )
 
             if not return_dict:
                 output = (logits,) + outputs[1:]
@@ -529,7 +526,6 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
             )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -804,24 +800,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         logits = self.lm_head(hidden_states)
         logits = logits.float()
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            new_vocab_size = logits.shape[-1]
-            shift_logits = shift_logits.view(-1, new_vocab_size)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.model.dtype,
-            )
+        loss = dist_cross_entropy(
+            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+        )
 
         if not return_dict:
             output = (logits,) + outputs[1:]
         return (loss,) + output if loss is not None else output
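The SP + PP bookkeeping above can be summarized with a single-process shape sketch. This is only an illustration of the shapes: the real split_forward_gather_backward / gather_forward_split_backward calls are autograd-aware collectives, and torch.chunk / torch.cat here merely stand in for them.

import torch

sp_size = 2
bsz, seq_len, hidden = 2, 16, 8
hidden_states = torch.randn(bsz, seq_len, hidden)

# First pipeline stage: split hidden states along the sequence dim (dim=1).
local = hidden_states.chunk(sp_size, dim=1)[0]  # one SP rank's shard: (2, 8, 8)

# Later stages receive the already-split tensor, so position ids must still
# cover the full sequence; hence seq_length *= sp_size on non-first stages.
assert local.shape[1] * sp_size == seq_len

# Last pipeline stage: gather along the sequence dim before norm / lm_head.
restored = torch.cat([local] * sp_size, dim=1)  # stand-in for the all-gather
assert restored.shape == hidden_states.shape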

View File

@@ -19,7 +19,7 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import ColoAttention, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy
 
 logger = logging.get_logger(__name__)
 
@@ -275,29 +275,9 @@ class MistralForwards:
         logits = self.lm_head(hidden_states)
         logits = logits.float()
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-                new_vocab_size = logits.shape[-1]
-                shift_logits = shift_logits.view(-1, new_vocab_size)
-                loss = cross_entropy_1d(
-                    shift_logits,
-                    shift_labels,
-                    process_group=shard_config.tensor_parallel_process_group,
-                    vocab_size=self.lm_head.out_features,
-                    dtype=self.model.dtype,
-                )
-            else:
-                shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                loss = loss_fct(shift_logits, shift_labels)
+        loss = dist_cross_entropy(
+            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+        )
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -708,22 +688,8 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         logits = self.lm_head(hidden_states)
         logits = logits.float()
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            new_vocab_size = logits.shape[-1]
-            shift_logits = shift_logits.view(-1, new_vocab_size)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.model.dtype,
-            )
+        loss = dist_cross_entropy(
+            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+        )
 
         if not return_dict:

View File

@@ -22,7 +22,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import cross_entropy_1d
+from ..layer import dist_cross_entropy
 
 logger = logging.get_logger(__name__)
 
@@ -330,30 +330,14 @@ class OPTPipelineForwards:
         )
         if stage_manager.is_last_stage():
             logits = self.lm_head(outputs[0]).contiguous()
-            loss = None
-            if labels is not None:
-                # move labels to correct device to enable model parallelism
-                labels = labels.to(logits.device)
-                # Shift so that tokens < n predict n
-                shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = labels[..., 1:].contiguous()
-                # Flatten the tokens
-                if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-                    new_vocab_size = logits.shape[-1]
-                    shift_logits = shift_logits.view(-1, new_vocab_size)
-                    shift_labels = shift_labels.view(-1)
-                    loss = cross_entropy_1d(
-                        shift_logits,
-                        shift_labels,
-                        process_group=shard_config.tensor_parallel_process_group,
-                        vocab_size=self.lm_head.out_features,
-                        dtype=self.model.decoder.dtype,
-                    )
-                else:
-                    loss_fct = CrossEntropyLoss()
-                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                    loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+            loss = dist_cross_entropy(
+                labels,
+                logits,
+                shard_config,
+                self.lm_head.out_features,
+                self.config.vocab_size,
+                self.model.decoder.dtype,
+            )
 
             if not return_dict:
                 output = (logits,) + outputs[1:]
@@ -971,25 +955,8 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         )
         logits = self.lm_head(outputs[0]).contiguous()
 
-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            new_vocab_size = logits.shape[-1]
-            shift_logits = shift_logits.view(-1, new_vocab_size)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.model.decoder.dtype,
-            )
+        loss = dist_cross_entropy(
+            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.decoder.dtype
+        )
 
         if not return_dict:

View File

@@ -32,7 +32,7 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import ColoAttention, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy
 
 
 class Qwen2PipelineForwards:
@@ -317,25 +317,9 @@ class Qwen2PipelineForwards:
         if stage_manager.is_last_stage():
             hidden_states = outputs[0]
             logits = self.lm_head(hidden_states)
-            loss = None
-            if labels is not None:
-                # Shift so that tokens < n predict n
-                shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = labels[..., 1:].contiguous()
-                # Flatten the tokens
-                loss_fct = CrossEntropyLoss()
-                shift_labels = shift_labels.view(-1)
-                # Enable model parallelism
-                shift_labels = shift_labels.to(shift_logits.device)
-                if shard_config.enable_tensor_parallelism:
-                    new_vocab_size = logits.shape[-1]
-                    shift_logits = shift_logits.view(-1, new_vocab_size)
-                    loss = cross_entropy_1d(
-                        shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
-                    )
-                else:
-                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                    loss = loss_fct(shift_logits, shift_labels)
+            loss = dist_cross_entropy(
+                labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
+            )
 
             if not return_dict:
                 output = (logits,) + outputs[1:]
@@ -737,26 +721,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
         logits = logits.float()
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            if shard_config.enable_tensor_parallelism:
-                new_vocab_size = logits.shape[-1]
-                shift_logits = shift_logits.view(-1, new_vocab_size)
-                loss = cross_entropy_1d(
-                    shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
-                )
-            else:
-                shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                loss = loss_fct(shift_logits, shift_labels)
+        loss = dist_cross_entropy(
+            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
+        )
 
         if not return_dict:
             output = (logits,) + outputs[1:]

View File

@@ -1,4 +1,3 @@
-import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union
 
@@ -66,13 +65,6 @@ class LlamaPolicy(Policy):
         else:
             norm_cls = RMSNorm
 
-        if self.pipeline_stage_manager is not None:
-            self.shard_config.enable_sequence_parallelism = False
-            self.shard_config.enable_sequence_overlap = False
-            self.shard_config.sequence_parallelism_mode = None
-            warnings.warn(
-                f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
-            )
         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None

View File

@@ -59,10 +59,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     if (
         booster.plugin.zero_stage in [1, 2]
         and booster.plugin.shard_config.enable_sequence_parallelism
+        and booster.plugin.shard_config.pipeline_stage_manager is None
         and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
     ):
+        master2working = sharded_optimizer.get_master_to_working_map()
         for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
-            working_p = sharded_optimizer.master_to_working_param[id(p2)]
+            working_p = master2working[id(p2)]
             grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
             grad_index = (
                 0
@@ -146,6 +148,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
+        {  # Ulysses + Flash attention
+            "tp_size": 1,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 0,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {  # Test ring + Flash attention
             "tp_size": 2,
             "pp_size": 1,
@@ -159,19 +174,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        {  # Ulysses + Flash attention
-            "tp_size": 1,
-            "pp_size": 2,
-            "sp_size": 2,
-            "num_microbatches": 2,
-            "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "all_to_all",
-            "enable_flash_attention": True,
-            "use_lazy_init": True,
-            "zero_stage": 1,
-            "precision": "fp16",
-            "initial_scale": 1,
-        },
         {
             "tp_size": 1,
             "pp_size": 1,
@@ -245,7 +247,6 @@ def run_llama_test(test_config):
     except Exception as e:
         print(f"Failed config: {test_config}")
         raise e
-
     clear_layout_converter()
     Randomizer.reset_index()
     torch.cuda.empty_cache()