[Feature] Split cross-entropy computation in SP (#5959)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* adapt chatglm, command-R, qwen

* debug

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* add comments

* q1 index only once

* remove events to simplify stream sync

* simplify forward/backward logic

* 2d ring forward passed

* 2d ring backward passed

* fixes

* fix ring attn loss

* 2D ring backward + llama passed

* merge

* update logger

* fix typo

* rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* remove typos

* fixes

* support GPT

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/6053/head
Wenxuan Tan 3 months ago committed by GitHub
parent b3db1058ec
commit 8fd25d6e09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1097,13 +1097,19 @@ def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication) return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
def gather_sp_output(hidden_states, sp_group, sp_mode, fp8_communication=False): def gather_sp_output(hidden_states, shard_config, sp_dim=1):
""" """
Gather the output of the last layer for cross entropy computation Gather the output of the last layer for cross entropy computation
""" """
sp_group = shard_config.sequence_parallel_process_group
sp_mode = shard_config.sequence_parallelism_mode
fp8_comm = shard_config.fp8_communication
if dist.get_world_size(sp_group) == 1:
return hidden_states
# Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group) # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group) scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
hidden_states = gather_forward_split_backward( hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=scale, fp8_communication=fp8_communication hidden_states, sp_dim, sp_group, grad_scale=scale, fp8_communication=fp8_comm
) )
return hidden_states return hidden_states

@ -433,7 +433,6 @@ class RingAttention(torch.autograd.Function):
assert ( assert (
sp_size % inner_ring_size == 0 sp_size % inner_ring_size == 0
), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}" ), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
logger = get_dist_logger() logger = get_dist_logger()
logger.info( logger.info(
f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!", f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!",
@ -898,6 +897,7 @@ class RingAttention(torch.autograd.Function):
local_sp_rank = dist.get_rank(sp_group) local_sp_rank = dist.get_rank(sp_group)
sp_size = dist.get_world_size(sp_group) sp_size = dist.get_world_size(sp_group)
# Using separate streams (pg) for concurrent kv and dkv comm may # Using separate streams (pg) for concurrent kv and dkv comm may
# cause NCCL "software caused connection abort" here... # cause NCCL "software caused connection abort" here...
local_kv_comm = RingComm(local_kv_group) local_kv_comm = RingComm(local_kv_group)
@ -1119,9 +1119,14 @@ class RingAttention(torch.autograd.Function):
the batch dim to a packed 1d sequence. Contingent on model forward shape definitions. the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.
Returns: Returns:
inputs_embeds: Packed input embeddings of shape [B, Sq // sp_size, ...]. torch.Tensor:
mask_info: A dictionary of mask info. Packed input embeddings of shape [B, Sq // sp_size, ...].
position_ids: Packed position ids of shape [..., Sq // sp_size].
Dict[str, Any]:
A dictionary containing mask info.
torch.Tensor:
Packed position ids of shape [..., Sq // sp_size].
""" """
_load_varlen_helpers() _load_varlen_helpers()

@ -153,7 +153,6 @@ def dist_cross_entropy(
labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
logits: torch.Tensor, # [B, S, Vocab_size] logits: torch.Tensor, # [B, S, Vocab_size]
shard_config: ShardConfig, shard_config: ShardConfig,
out_features: int,
vocab_size: int, vocab_size: int,
dtype: torch.dtype, dtype: torch.dtype,
seq_dim: int = 1, seq_dim: int = 1,
@ -226,13 +225,13 @@ def dist_cross_entropy(
logits, logits,
labels, labels,
process_group=shard_config.tensor_parallel_process_group, process_group=shard_config.tensor_parallel_process_group,
vocab_size=out_features, vocab_size=vocab_size,
dtype=dtype, dtype=dtype,
mode="sum", mode="sum",
) )
else: else:
# NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D # NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D
logits = logits.view(-1, vocab_size) logits = logits.view(-1, logits.size(-1))
loss = loss_fct(logits, labels) loss = loss_fct(logits, labels)
# Reduce loss instead of gathering logits over seq dim for savings # Reduce loss instead of gathering logits over seq dim for savings

@ -313,19 +313,19 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
# Matrix multiply. # Matrix multiply.
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel_mode == "split_gather":
if self.seq_parallel_mode is None: input_parallel = input_
# Set up backprop all-reduce. output_parallel = matmul_gather_forward_reducescatter_backward(
input_parallel = reduce_backward(input_, self.process_group, fp8_communication=self.fp8_communication)
output_parallel = matmul_with_async_comm(
input_parallel, input_parallel,
self.weight, self.weight,
bias, bias,
self.process_group, self.process_group,
self.async_communication, True,
1,
self.overlap,
fp8_communication=self.fp8_communication, fp8_communication=self.fp8_communication,
) )
elif self.seq_parallel_mode == "split_gather": elif self.seq_parallel_mode == "ring":
input_parallel = input_ input_parallel = input_
output_parallel = matmul_gather_forward_reducescatter_backward( output_parallel = matmul_gather_forward_reducescatter_backward(
input_parallel, input_parallel,
@ -335,13 +335,22 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
True, True,
1, 1,
self.overlap, self.overlap,
True,
fp8_communication=self.fp8_communication, fp8_communication=self.fp8_communication,
) )
elif self.seq_parallel_mode == "ring": elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
input_parallel = input_ # Set up backprop all-reduce.
output_parallel = matmul_gather_forward_reducescatter_backward( input_parallel = reduce_backward(input_, self.process_group)
input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap, True output_parallel = matmul_with_async_comm(
input_parallel,
self.weight,
bias,
self.process_group,
self.async_communication,
fp8_communication=self.fp8_communication,
) )
else:
raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
@ -553,7 +562,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
handle.wait() handle.wait()
output = torch.cat(output_parallel_list, dim=-1) output = torch.cat(output_parallel_list, dim=-1)
else: else:
if self.seq_parallel_mode is None: if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
output_parallel = torch.matmul(input_, self.weight) output_parallel = torch.matmul(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication) output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
elif self.seq_parallel_mode == "split_gather": elif self.seq_parallel_mode == "split_gather":
@ -567,8 +576,12 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
elif self.seq_parallel_mode == "ring": elif self.seq_parallel_mode == "ring":
output_parallel = torch.matmul(input_, self.weight) output_parallel = torch.matmul(input_, self.weight)
output = reducescatter_forward_gather_backward( output = reducescatter_forward_gather_backward(
output_parallel, self.process_group, 1, self.fp8_communication output_parallel,
self.process_group,
1,
) )
else:
raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
if not self.skip_bias_add: if not self.skip_bias_add:
if self.bias is not None: if self.bias is not None:

@ -309,6 +309,9 @@ def split_batch_zigzag(
""" """
sp_size = dist.get_world_size(sp_group) sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group) sp_rank = dist.get_rank(sp_group)
if sp_size == 1:
return batch
if isinstance(batch, torch.Tensor): if isinstance(batch, torch.Tensor):
batch = [batch] batch = [batch]
seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1 seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1
@ -364,6 +367,9 @@ def split_varlen_zigzag(
""" """
sp_size = dist.get_world_size(sp_group) sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group) sp_rank = dist.get_rank(sp_group)
if sp_size == 1:
return batch
if is_2d: if is_2d:
assert max_seqlen > 0, "max_seqlen must be provided for 2D input" assert max_seqlen > 0, "max_seqlen must be provided for 2D input"

@ -365,12 +365,13 @@ class BloomPipelineForwards:
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states).contiguous() lm_logits = self.lm_head(hidden_states).contiguous()
loss = None
if labels is not None:
loss = dist_cross_entropy( loss = dist_cross_entropy(
labels, labels,
lm_logits, lm_logits,
shard_config, shard_config,
self.lm_head.out_features, self.lm_head.out_features,
self.config.vocab_size,
self.transformer.dtype, self.transformer.dtype,
) )
@ -1036,8 +1037,10 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
loss = dist_cross_entropy( loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype
) )
if not return_dict: if not return_dict:

@ -4,7 +4,6 @@ from typing import List, Optional, Tuple
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.utils import logging from transformers.utils import logging
@ -13,10 +12,13 @@ from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer import AttnMaskType, ColoAttention from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer._operation import ( from colossalai.shardformer.layer._operation import (
all_to_all_comm, all_to_all_comm,
gather_forward_split_backward, gather_sp_output,
is_share_sp_tp,
split_forward_gather_backward, split_forward_gather_backward,
) )
from ..layer import dist_cross_entropy
def get_flash_core_attention_forward(): def get_flash_core_attention_forward():
from .chatglm2_6b.modeling_chatglm import CoreAttention from .chatglm2_6b.modeling_chatglm import CoreAttention
@ -138,6 +140,7 @@ class ChatGLMPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None, shard_config: ShardConfig = None,
force_sp_output_gather: Optional[bool] = True,
): ):
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
output_hidden_states = ( output_hidden_states = (
@ -180,6 +183,15 @@ class ChatGLMPipelineForwards:
if full_attention_mask is None: if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
# Support SP + PP
sp_size = shard_config.sequence_parallel_size
sp_mode = shard_config.sequence_parallelism_mode
sp_group = shard_config.sequence_parallel_process_group
# For generating full positions ids (the states will be gathered along the seq dim before attention fwd).
if sp_mode != "ring_attn" and not stage_manager.is_first_stage():
seq_length *= sp_size
# Rotary positional embeddings # Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length) rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
if position_ids is not None: if position_ids is not None:
@ -200,21 +212,14 @@ class ChatGLMPipelineForwards:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
start_idx, end_idx = stage_index[0], stage_index[1] start_idx, end_idx = stage_index[0], stage_index[1]
if shard_config and shard_config.enable_sequence_parallelism: # Keep the input split across all PP stages
if shard_config.sequence_parallelism_mode == "split_gather": if stage_manager.is_first_stage():
hidden_states = split_forward_gather_backward( if shard_config.enable_sequence_parallelism:
hidden_states, if sp_mode == "split_gather":
dim=0,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward( hidden_states = split_forward_gather_backward(
hidden_states, hidden_states,
dim=0, dim=0,
process_group=shard_config.sequence_parallel_process_group, process_group=sp_group,
grad_scale=1 / shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
) )
elif shard_config.sequence_parallelism_mode == "all_to_all": elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward( hidden_states = split_forward_gather_backward(
@ -223,6 +228,7 @@ class ChatGLMPipelineForwards:
process_group=shard_config.sequence_parallel_process_group, process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size, grad_scale=1 / shard_config.sequence_parallel_size,
) )
for idx in range(start_idx, end_idx): for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx) layer = self.encoder._get_layer(idx)
if output_hidden_states: if output_hidden_states:
@ -248,35 +254,19 @@ class ChatGLMPipelineForwards:
if use_cache: if use_cache:
presents = presents + (kv_cache,) presents = presents + (kv_cache,)
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
)
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
# final layer_norm # final layer_norm
if self.encoder.post_layer_norm: if self.encoder.post_layer_norm:
hidden_states = self.encoder.final_layernorm(hidden_states) hidden_states = self.encoder.final_layernorm(hidden_states)
# Gather seq-wise in the final output stage
if shard_config.enable_sequence_parallelism:
sp_mode = shard_config.sequence_parallelism_mode
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
hidden_states = gather_sp_output(hidden_states, shard_config, sp_dim=0)
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
@ -333,6 +323,7 @@ class ChatGLMPipelineForwards:
hidden_states=hidden_states, hidden_states=hidden_states,
stage_index=stage_index, stage_index=stage_index,
shard_config=shard_config, shard_config=shard_config,
force_sp_output_gather=False,
) )
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
@ -340,17 +331,21 @@ class ChatGLMPipelineForwards:
hidden_states = hidden_states[-1:] hidden_states = hidden_states[-1:]
lm_logits = self.transformer.output_layer(hidden_states) lm_logits = self.transformer.output_layer(hidden_states)
lm_logits = lm_logits.transpose(0, 1).contiguous() lm_logits = lm_logits.transpose(0, 1).contiguous()
loss = None loss = None
if labels is not None: if labels is not None:
lm_logits = lm_logits.to(torch.float32) # ChatGLM doesn't have lm_head split
# Shift so that tokens < n predict n enable_tp = shard_config.enable_tensor_parallelism
shift_logits = lm_logits[..., :-1, :].contiguous() shard_config.enable_tensor_parallelism = False
shift_labels = labels[..., 1:].contiguous() loss = dist_cross_entropy(
# Flatten the tokens labels,
loss_fct = CrossEntropyLoss(ignore_index=-100) lm_logits,
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) shard_config,
lm_logits = lm_logits.to(hidden_states.dtype) self.transformer.output_layer.out_features,
loss = loss.to(hidden_states.dtype) lm_logits.dtype,
)
shard_config.enable_tensor_parallelism = enable_tp
if not return_dict: if not return_dict:
output = (lm_logits,) + transformer_outputs[1:] output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
@ -379,6 +374,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
force_sp_output_gather: Optional[bool] = True,
): ):
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -456,22 +452,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
use_cache=use_cache, use_cache=use_cache,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
) )
if shard_config.enable_sequence_parallelism:
if sp_mode in ["split_gather"]: if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
hidden_states = gather_forward_split_backward( hidden_states = gather_sp_output(hidden_states, shard_config, sp_dim=0)
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=0,
process_group=sp_group,
grad_scale=sp_size,
fp8_communication=shard_config.fp8_communication,
)
if not return_dict: if not return_dict:
return tuple( return tuple(

@ -17,14 +17,13 @@ from transformers.models.cohere.modeling_cohere import (
from transformers.utils import logging from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import ( from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward
all_to_all_comm,
gather_forward_split_backward,
split_forward_gather_backward,
)
from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, dist_cross_entropy from ..layer import ColoAttention, dist_cross_entropy
from ..layer._operation import gather_sp_output, is_share_sp_tp
_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring"]
_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"] _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
@ -52,6 +51,7 @@ class CommandPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None, shard_config: ShardConfig = None,
force_sp_output_gather: bool = True,
): ):
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -93,10 +93,16 @@ class CommandPipelineForwards:
if not isinstance(past_key_values, StaticCache): if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length() past_seen_tokens = past_key_values.get_seq_length()
# NOTE: For generating full positions ids
# (the states will be gathered along the seq dim before attention fwd).
if shard_config.sequence_parallelism_mode != "ring_attn" and not stage_manager.is_first_stage():
seq_length *= shard_config.sequence_parallel_size
if cache_position is None: if cache_position is None:
if isinstance(past_key_values, StaticCache): if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.") raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device) cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
seq_length_with_past = seq_length + past_seen_tokens seq_length_with_past = seq_length + past_seen_tokens
@ -136,7 +142,7 @@ class CommandPipelineForwards:
) )
use_cache = False use_cache = False
if shard_config and shard_config.enable_sequence_parallelism: if stage_manager.is_first_stage() and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
hidden_states = split_forward_gather_backward( hidden_states = split_forward_gather_backward(
hidden_states, hidden_states,
@ -208,23 +214,10 @@ class CommandPipelineForwards:
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
sp_mode = shard_config.sequence_parallelism_mode
if shard_config and shard_config.enable_sequence_parallelism: if shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
hidden_states = gather_forward_split_backward( hidden_states = gather_sp_output(hidden_states, shard_config)
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
# add hidden states from the last decoder layer # add hidden states from the last decoder layer
if output_hidden_states: if output_hidden_states:
@ -327,6 +320,7 @@ class CommandPipelineForwards:
hidden_states=hidden_states, hidden_states=hidden_states,
stage_index=stage_index, stage_index=stage_index,
shard_config=shard_config, shard_config=shard_config,
force_sp_output_gather=False,
) )
past_key_values = None past_key_values = None
@ -335,9 +329,10 @@ class CommandPipelineForwards:
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
logits = logits * self.logit_scale logits = logits * self.logit_scale
logits = logits.float() logits = logits.float()
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype loss = None
) if labels is not None:
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
@ -482,6 +477,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
force_sp_output_gather: bool = True,
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@ -584,14 +580,10 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather": # Cases that don't support parallelizing cross entropy computation along sequence
hidden_states = gather_forward_split_backward( if shard_config.enable_sequence_parallelism:
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
) hidden_states = gather_sp_output(hidden_states, shard_config)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer # add hidden states from the last decoder layer
if output_hidden_states: if output_hidden_states:
@ -676,6 +668,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position, cache_position=cache_position,
force_sp_output_gather=False,
) )
hidden_states = outputs[0] hidden_states = outputs[0]
@ -683,12 +676,14 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
logits = logits * self.logit_scale logits = logits * self.logit_scale
logits = logits.float() logits = logits.float()
loss = None
if labels is not None:
loss = dist_cross_entropy( loss = dist_cross_entropy(
labels, labels,
logits, logits,
shard_config, shard_config,
self.lm_head.out_features, self.lm_head.out_features,
self.config.vocab_size,
self.model.dtype, self.model.dtype,
) )

@ -21,8 +21,9 @@ from transformers.models.gpt2.modeling_gpt2 import (
from transformers.utils import logging from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.layer import ColoAttention, RingAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.layer._operation import gather_sp_output, split_forward_gather_backward
from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard import ShardConfig
from ..layer import dist_cross_entropy from ..layer import dist_cross_entropy
@ -39,10 +40,16 @@ def _get_attention_mask(
encoder_hidden_states: Optional[torch.Tensor], encoder_hidden_states: Optional[torch.Tensor],
encoder_attention_mask: Optional[torch.FloatTensor], encoder_attention_mask: Optional[torch.FloatTensor],
) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]: ) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]:
batch_size, seq_len = hidden_states.shape[:2] # Received input is already split for non-first pipeline stages,
# but attn mask isn't
batch_size = hidden_states.size(0)
seq_len = attention_mask.size(-1)
sp_mode = shard_config.sequence_parallelism_mode
# If a 2D or 3D attention mask is provided for the cross-attention # If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.add_cross_attention and encoder_hidden_states is not None: if self.config.add_cross_attention and encoder_hidden_states is not None:
assert not sp_mode == "ring_attn", "Ring Attention only supports decoder-only."
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
if shard_config.enable_flash_attention: if shard_config.enable_flash_attention:
encoder_attention_mask = ColoAttention.prepare_attn_kwargs( encoder_attention_mask = ColoAttention.prepare_attn_kwargs(
@ -62,6 +69,7 @@ def _get_attention_mask(
encoder_attention_mask = {"attention_mask": None} encoder_attention_mask = {"attention_mask": None}
else: else:
encoder_attention_mask = None encoder_attention_mask = None
# GPT2Attention mask. # GPT2Attention mask.
past_key_values_length = 0 past_key_values_length = 0
if past_key_values is not None and past_key_values[0] is not None: if past_key_values is not None and past_key_values[0] is not None:
@ -69,6 +77,7 @@ def _get_attention_mask(
if shard_config.enable_flash_attention: if shard_config.enable_flash_attention:
if attention_mask is not None: if attention_mask is not None:
attention_mask = attention_mask.view(batch_size, -1) attention_mask = attention_mask.view(batch_size, -1)
attention_mask = ColoAttention.prepare_attn_kwargs( attention_mask = ColoAttention.prepare_attn_kwargs(
(batch_size, 1, seq_len, seq_len + past_key_values_length), (batch_size, 1, seq_len, seq_len + past_key_values_length),
hidden_states.dtype, hidden_states.dtype,
@ -123,6 +132,7 @@ class GPT2PipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None, shard_config: ShardConfig = None,
force_sp_gather: Optional[bool] = True,
) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
# Please refer to original code of transformers for more details. # Please refer to original code of transformers for more details.
@ -146,16 +156,15 @@ class GPT2PipelineForwards:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False use_cache = False
if stage_manager.is_first_stage(): disable_pp = stage_manager is None
if disable_pp or stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
input_shape = input_ids.size() input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1]) input_ids = input_ids.view(-1, input_shape[-1])
input_ids.shape[0]
elif inputs_embeds is not None: elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1] input_shape = inputs_embeds.size()[:-1]
inputs_embeds.shape[0]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
@ -176,7 +185,7 @@ class GPT2PipelineForwards:
# head_mask has shape n_layer x batch x n_heads x N x N # head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer) head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if stage_manager.is_first_stage(): if disable_pp or stage_manager.is_first_stage():
if position_ids is None: if position_ids is None:
position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device) position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0) position_ids = position_ids.unsqueeze(0)
@ -190,9 +199,7 @@ class GPT2PipelineForwards:
hidden_states = hidden_states + token_type_embeds hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states) hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),) attn_kwargs, encoder_attention_mask = _get_attention_mask(
attention_mask, encoder_attention_mask = _get_attention_mask(
self, self,
shard_config, shard_config,
hidden_states, hidden_states,
@ -215,23 +222,43 @@ class GPT2PipelineForwards:
# split the input tensor along sequence dimension # split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config and shard_config.enable_sequence_parallelism: sp_mode = shard_config.sequence_parallelism_mode
if shard_config.sequence_parallelism_mode == "split_gather": sp_group = shard_config.sequence_parallel_process_group
if disable_pp or stage_manager.is_first_stage():
# Ring Attention's special zigzag batch processing
if sp_mode == "ring_attn":
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
if not attention_mask.bool().all():
hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
attention_mask, sp_group, hidden_states, position_ids
)
else:
hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
# Other sp modes
else:
if sp_mode == "split_gather":
hidden_states = split_forward_gather_backward( hidden_states = split_forward_gather_backward(
hidden_states, hidden_states,
dim=1, dim=1,
process_group=shard_config.tensor_parallel_process_group, process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
) )
elif sp_mode == "ring_attn":
# Later stages already received split hidden states
_, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)
del attention_mask
# Going through held blocks. # Going through held blocks.
if disable_pp:
start_idx, end_idx = 0, len(self.h)
else:
start_idx, end_idx = stage_index[0], stage_index[1] start_idx, end_idx = stage_index[0], stage_index[1]
for i in range(start_idx, end_idx): for i in range(start_idx, end_idx):
block = self.h[i] block = self.h[i]
torch.cuda.set_device(hidden_states.device) torch.cuda.set_device(hidden_states.device)
# Ensure that attention_mask is always on the same device as hidden_states # Ensure that attention_mask is always on the same device as hidden_states
if torch.is_tensor(attention_mask): if torch.is_tensor(attn_kwargs):
attention_mask = attention_mask.to(hidden_states.device) attn_kwargs = attn_kwargs.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor): if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device) head_mask = head_mask.to(hidden_states.device)
if output_hidden_states: if output_hidden_states:
@ -242,7 +269,7 @@ class GPT2PipelineForwards:
block.__call__, block.__call__,
hidden_states, hidden_states,
None, None,
attention_mask, attn_kwargs,
head_mask[i], head_mask[i],
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
@ -253,7 +280,7 @@ class GPT2PipelineForwards:
outputs = block( outputs = block(
hidden_states, hidden_states,
layer_past=None, layer_past=None,
attention_mask=attention_mask, attention_mask=attn_kwargs,
head_mask=head_mask[i], head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
@ -270,26 +297,25 @@ class GPT2PipelineForwards:
if self.config.add_cross_attention: if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
# When sequence parallelism done, gather the output tensor in forward and split it in backward # When sequence parallelism is done, gather the output tensor in forward and split it in backward
if shard_config and shard_config.enable_sequence_parallelism: gather_output = (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode)
if shard_config.sequence_parallelism_mode == "split_gather": if disable_pp or stage_manager.is_last_stage():
hidden_states = gather_forward_split_backward( if gather_output:
hidden_states, hidden_states = gather_sp_output(hidden_states, shard_config)
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
if stage_manager.is_last_stage(): # gather_sp_output could've changed seq length.
hidden_states = self.ln_f(hidden_states) input_shape = (*input_shape[:-1], hidden_states.size(-2))
output_shape = input_shape + (hidden_states.size(-1),)
if disable_pp or stage_manager.is_last_stage():
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape) hidden_states = hidden_states.view(output_shape)
# Add last hidden state # Add last hidden state
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if stage_manager.is_last_stage(): if disable_pp or stage_manager.is_last_stage():
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
@ -366,16 +392,28 @@ class GPT2PipelineForwards:
hidden_states=hidden_states, hidden_states=hidden_states,
stage_index=stage_index, stage_index=stage_index,
shard_config=shard_config, shard_config=shard_config,
force_sp_gather=False,
) )
# If not at the last stage, return hidden_states as in GPT2Model # If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage(): disable_pp = stage_manager is None
if (not disable_pp) and (not stage_manager.is_last_stage()):
return {"hidden_states": outputs["hidden_states"]} return {"hidden_states": outputs["hidden_states"]}
hidden_states = outputs[0] hidden_states = outputs[0]
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
if shard_config.sequence_parallelism_mode == "ring_attn":
# Split labels in a zigzag fashion too
sp_group = shard_config.sequence_parallel_process_group
if not attention_mask.bool().all():
# [B, max_seqlen // sp_size]
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
else:
labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
if labels is not None:
loss = dist_cross_entropy( loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype
) )
if not return_dict: if not return_dict:
@ -770,7 +808,7 @@ class GPT2PipelineForwards:
) )
def get_gpt2_flash_attention_forward(): def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None):
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
def forward( def forward(
@ -817,6 +855,21 @@ def get_gpt2_flash_attention_forward():
if self.scale_attn_by_inverse_layer_idx: if self.scale_attn_by_inverse_layer_idx:
scale /= float(self.layer_idx + 1) scale /= float(self.layer_idx + 1)
dropout_p = self.attn_dropout.p if self.training else 0.0 dropout_p = self.attn_dropout.p if self.training else 0.0
sp_mode = shard_config.sequence_parallelism_mode
sp_group = shard_config.sequence_parallel_process_group
if sp_mode == "ring_attn":
attn_output = RingAttention.attention(
query,
key,
value,
sp_group,
**attention_mask,
dropout_p=dropout_p,
scale=scale,
inner_ring_size=shard_config.inner_ring_size,
)
else:
attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output) attn_output = self.c_proj(attn_output)
@ -828,466 +881,6 @@ def get_gpt2_flash_attention_forward():
return forward return forward
def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig):
def forward(
self: GPT2Model,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
attention_mask, encoder_attention_mask = _get_attention_mask(
self,
shard_config,
hidden_states,
past_key_values,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if torch.is_tensor(attention_mask):
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
return forward
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
attention_mask, encoder_attention_mask = _get_attention_mask(
self,
shard_config,
hidden_states,
past_key_values,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger = logging.get_logger(__name__)
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward(
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if torch.is_tensor(attention_mask):
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
# When sequence parallelism done, gather the output tensor in forward and split it in backward
hidden_states = gather_forward_split_backward(
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
return forward
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
from transformers import GPT2LMHeadModel
def forward(
self: GPT2LMHeadModel,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)
return forward
def get_jit_fused_gpt2_mlp_forward(): def get_jit_fused_gpt2_mlp_forward():
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP from transformers.models.gpt2.modeling_gpt2 import GPT2MLP

@ -25,7 +25,6 @@ from transformers.models.llama.modeling_llama import (
from transformers.utils import logging from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import AttnMaskType
from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard import ShardConfig
@ -58,10 +57,7 @@ class LlamaPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None, shard_config: ShardConfig = None,
# Split output only when computing cross entropy using llama_for_causal_lm_forward force_sp_gather: bool = True, # Set to false only when computing cross entropy
# or get_lm_forward_with_dist_cross_entropy
# Default to True to avoid bug when calling classification forward from huggingface
force_sp_output_gather: bool = True,
): ):
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -78,8 +74,9 @@ class LlamaPipelineForwards:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
disable_pp = stage_manager is None
# retrieve input_ids and inputs_embeds # retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage(): if disable_pp or stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
@ -88,10 +85,10 @@ class LlamaPipelineForwards:
batch_size, seq_length, _ = inputs_embeds.shape[:2] batch_size, seq_length, _ = inputs_embeds.shape[:2]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds hidden_states = inputs_embeds
device = hidden_states.device
else: else:
input_shape = hidden_states.shape[:-1] input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
@ -101,8 +98,8 @@ class LlamaPipelineForwards:
sp_mode = shard_config.sequence_parallelism_mode sp_mode = shard_config.sequence_parallelism_mode
sp_group = shard_config.sequence_parallel_process_group sp_group = shard_config.sequence_parallel_process_group
sp_size = shard_config.sequence_parallel_size sp_size = shard_config.sequence_parallel_size
if sp_mode == "all_to_all" and not stage_manager.is_first_stage(): # Generating full positions ids for modes that gather sequence before attn
# For generating full positions ids, as the states will be gather along the seq dim in the attention layer later. if stage_manager and (sp_mode != "ring_attn" and not stage_manager.is_first_stage()):
seq_length *= sp_size seq_length *= sp_size
past_seen_tokens = 0 past_seen_tokens = 0
@ -117,7 +114,6 @@ class LlamaPipelineForwards:
seq_length_with_past = seq_length + past_seen_tokens seq_length_with_past = seq_length + past_seen_tokens
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions: if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False output_attentions = False
@ -130,14 +126,13 @@ class LlamaPipelineForwards:
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage no_split_input = disable_pp or not stage_manager.is_first_stage()
if not stage_manager.is_first_stage() and sp_mode == "ring_attn": if no_split_input and sp_mode == "ring_attn":
_, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group) _, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)
elif shard_config.enable_flash_attention: elif shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
attn_kwargs = ColoAttention.prepare_attn_kwargs( attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
mask_shape, mask_shape,
hidden_states.dtype, hidden_states.dtype,
hidden_states.device, hidden_states.device,
@ -146,15 +141,15 @@ class LlamaPipelineForwards:
invert=(sp_mode != "ring_attn"), invert=(sp_mode != "ring_attn"),
) )
else: else:
attn_kwargs = self._update_causal_mask(attention_mask, hidden_states, cache_position) attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position)
# Support SP + PP # Support SP + PP. Later stages have already received the split input.
# TODO: support padded casual cu_seqlens across stages split_input = disable_pp or stage_manager.is_first_stage()
if stage_manager.is_first_stage(): if split_input:
# Ring Attention zigzag batch processing # Ring Attention zigzag batch processing
if sp_mode == "ring_attn": if sp_mode == "ring_attn":
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: if not attention_mask.bool().all():
hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
attention_mask, sp_group, hidden_states, position_ids attention_mask, sp_group, hidden_states, position_ids
) )
@ -181,8 +176,8 @@ class LlamaPipelineForwards:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
next_decoder_cache = None next_decoder_cache = None
start_idx, end_idx = (0, len(self.layers)) if disable_pp else (stage_index[0], stage_index[1])
start_idx, end_idx = stage_index[0], stage_index[1]
num_ckpt_layers = 0 num_ckpt_layers = 0
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
num_ckpt_layers = end_idx - start_idx num_ckpt_layers = end_idx - start_idx
@ -228,18 +223,16 @@ class LlamaPipelineForwards:
if output_attentions: if output_attentions:
all_self_attns += (layer_outputs[1],) all_self_attns += (layer_outputs[1],)
if stage_manager.is_last_stage(): if disable_pp or stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): if (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode): # noqa
hidden_states = gather_sp_output( hidden_states = gather_sp_output(hidden_states, shard_config)
hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer # add hidden states from the last decoder layer
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None next_cache = next_decoder_cache if use_cache else None
if stage_manager.is_last_stage(): if disable_pp or stage_manager.is_last_stage():
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
@ -257,7 +250,7 @@ class LlamaPipelineForwards:
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
) )
# always return dict for imediate stage # always return dict for intermediate stage
return {"hidden_states": hidden_states} return {"hidden_states": hidden_states}
@staticmethod @staticmethod
@ -323,7 +316,7 @@ class LlamaPipelineForwards:
# Split labels in a zigzag fashion too # Split labels in a zigzag fashion too
sp_group = shard_config.sequence_parallel_process_group sp_group = shard_config.sequence_parallel_process_group
if attention_mask.bool().all(): if attention_mask.bool().all():
labels = split_batch_zigzag(labels, sp_group, seq_dim=1) labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
else: else:
# [B, max_seqlen // sp_size] # [B, max_seqlen // sp_size]
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
@ -345,16 +338,17 @@ class LlamaPipelineForwards:
hidden_states=hidden_states, hidden_states=hidden_states,
stage_index=stage_index, stage_index=stage_index,
shard_config=shard_config, shard_config=shard_config,
force_sp_output_gather=False, force_sp_gather=False,
) )
past_key_values = None past_key_values = None
if stage_manager.is_last_stage(): disable_pp = stage_manager is None
if disable_pp or stage_manager.is_last_stage():
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
loss = dist_cross_entropy( loss = None
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype if labels is not None:
) loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
@ -629,263 +623,3 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
return forward return forward
def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
logger = logging.get_logger(__name__)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
# Split output only when computing cross entropy using llama_for_causal_lm_forward
# or get_lm_forward_with_dist_cross_entropy
# Default to True to avoid bug when calling classification forward from huggingface
force_sp_output_gather: bool = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
past_seen_tokens = 0
seq_len = inputs_embeds.shape[1]
batch_size = inputs_embeds.shape[0]
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
if shard_config.enable_flash_attention:
mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len)
attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
mask_shape,
inputs_embeds.dtype,
inputs_embeds.device,
q_padding_mask=attention_mask,
is_causal=True,
invert=(sp_mode != "ring_attn"),
)
else:
attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
# Ring Attention zigzag batch processing
if sp_mode == "ring_attn":
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
inputs_embeds, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
attention_mask, sp_group, inputs_embeds, position_ids
)
else:
inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group)
attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]} # drop redundant tensors
elif is_share_sp_tp(sp_mode):
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attn_kwargs,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attn_kwargs,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# Cases that don't support parallelizing cross entropy computation along sequence
if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
hidden_states = gather_sp_output(
hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
)
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return forward
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
from transformers import LlamaForCausalLM
def forward(
self: LlamaForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output:
# Special processing: Split labels in a zigzag fashion too
sp_group = shard_config.sequence_parallel_process_group
if attention_mask.bool().all():
labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
else:
# [B, max_seq_len // sp_size]
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
force_sp_output_gather=False,
)
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return forward

@ -274,10 +274,9 @@ class MistralForwards:
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
logits = logits.float() logits = logits.float()
loss = None
loss = dist_cross_entropy( if labels is not None:
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
)
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
@ -687,10 +686,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
logits = logits.float() logits = logits.float()
loss = None
loss = dist_cross_entropy( if labels is not None:
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
)
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]

@ -330,12 +330,13 @@ class OPTPipelineForwards:
) )
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
logits = self.lm_head(outputs[0]).contiguous() logits = self.lm_head(outputs[0]).contiguous()
loss = None
if labels is not None:
loss = dist_cross_entropy( loss = dist_cross_entropy(
labels, labels,
logits, logits,
shard_config, shard_config,
self.lm_head.out_features, self.lm_head.out_features,
self.config.vocab_size,
self.model.decoder.dtype, self.model.decoder.dtype,
) )
@ -955,9 +956,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
) )
logits = self.lm_head(outputs[0]).contiguous() logits = self.lm_head(outputs[0]).contiguous()
loss = dist_cross_entropy( loss = None
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.decoder.dtype if labels is not None:
) loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.decoder.dtype)
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]

@ -32,14 +32,12 @@ except ImportError:
from transformers.utils import logging from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import ( from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward
all_to_all_comm,
gather_forward_split_backward,
split_forward_gather_backward,
)
from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, dist_cross_entropy from ..layer import ColoAttention, dist_cross_entropy
from ..layer._operation import gather_sp_output
from ..layer.utils import is_share_sp_tp
class Qwen2PipelineForwards: class Qwen2PipelineForwards:
@ -64,6 +62,7 @@ class Qwen2PipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None, shard_config: ShardConfig = None,
force_sp_output_gather: bool = True,
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -115,6 +114,14 @@ class Qwen2PipelineForwards:
past_key_values_length = past_key_values[0][0].shape[2] past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length seq_length_with_past = seq_length_with_past + past_key_values_length
# Support SP + PP
sp_size = shard_config.sequence_parallel_size
sp_group = shard_config.sequence_parallel_process_group
sp_mode = shard_config.sequence_parallelism_mode
# For generating full positions ids (the states will be gathered along the seq dim before attention fwd).
if sp_mode != "ring_attn" and not stage_manager.is_first_stage():
seq_length *= sp_size
if position_ids is None: if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange( position_ids = torch.arange(
@ -151,7 +158,6 @@ class Qwen2PipelineForwards:
elif self._attn_implementation == "sdpa" and not output_attentions: elif self._attn_implementation == "sdpa" and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on # output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases. # the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask, attention_mask,
(batch_size, seq_length), (batch_size, seq_length),
@ -160,7 +166,6 @@ class Qwen2PipelineForwards:
) )
else: else:
# 4d mask is passed through the layers # 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask( attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, attention_mask,
(batch_size, seq_length), (batch_size, seq_length),
@ -169,21 +174,20 @@ class Qwen2PipelineForwards:
sliding_window=self.config.sliding_window, sliding_window=self.config.sliding_window,
) )
if shard_config and shard_config.enable_sequence_parallelism: if stage_manager.is_first_stage():
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: if shard_config.enable_sequence_parallelism:
if is_share_sp_tp(sp_mode):
hidden_states = split_forward_gather_backward( hidden_states = split_forward_gather_backward(
hidden_states, hidden_states,
dim=1, dim=1,
process_group=shard_config.tensor_parallel_process_group, process_group=sp_group,
fp8_communication=shard_config.fp8_communication,
) )
elif shard_config.sequence_parallelism_mode == "all_to_all": elif sp_mode == "all_to_all":
hidden_states = split_forward_gather_backward( hidden_states = split_forward_gather_backward(
hidden_states, hidden_states,
dim=1, dim=1,
process_group=shard_config.sequence_parallel_process_group, process_group=sp_group,
grad_scale=1 / shard_config.sequence_parallel_size, grad_scale=1 / sp_size,
fp8_communication=shard_config.fp8_communication,
) )
# decoder layers # decoder layers
@ -241,23 +245,10 @@ class Qwen2PipelineForwards:
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
if shard_config.enable_sequence_parallelism:
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
hidden_states = gather_sp_output(hidden_states, shard_config)
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
hidden_states = gather_forward_split_backward(
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
# add hidden states from the last decoder layer # add hidden states from the last decoder layer
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
@ -351,15 +342,18 @@ class Qwen2PipelineForwards:
hidden_states=hidden_states, hidden_states=hidden_states,
stage_index=stage_index, stage_index=stage_index,
shard_config=shard_config, shard_config=shard_config,
force_sp_output_gather=False,
) )
past_key_values = None past_key_values = None
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
hidden_states = outputs[0] hidden_states = outputs[0]
if hidden_states.shape[1] == 2:
pass
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
loss = dist_cross_entropy( loss = None
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype if labels is not None:
) loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
@ -541,7 +535,6 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
# Because the input can be padded, the absolute sequence length depends on the max position id. # Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None: if past_key_value is not None:
@ -635,6 +628,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
force_sp_output_gather: bool = True,
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@ -750,14 +744,9 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather": if shard_config.enable_sequence_parallelism:
hidden_states = gather_forward_split_backward( if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication hidden_states = gather_sp_output(hidden_states, shard_config)
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer # add hidden states from the last decoder layer
if output_hidden_states: if output_hidden_states:
@ -834,14 +823,15 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
force_sp_output_gather=False,
) )
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
logits = logits.float() logits = logits.float()
loss = dist_cross_entropy( loss = None
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype if labels is not None:
) loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]

@ -64,7 +64,7 @@ class ChatGLMPolicy(Policy):
if sp_mode == "ring": if sp_mode == "ring":
warnings.warn( warnings.warn(
f"For ChatGLM2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather" f"For ChatGLM2, sequence parallelism doesn't support mode {sp_mode} yet, will set to be split_gather"
) )
sp_mode = "split_gather" sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap overlap = self.shard_config.enable_sequence_overlap

@ -6,14 +6,7 @@ from torch import Tensor, nn
import colossalai.shardformer.layer as col_nn import colossalai.shardformer.layer as col_nn
from ..modeling.gpt2 import ( from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, get_jit_fused_gpt2_mlp_forward
GPT2PipelineForwards,
get_gpt2_flash_attention_forward,
get_gpt_model_forward_for_flash_attn,
get_jit_fused_gpt2_mlp_forward,
get_lm_forward_with_dist_cross_entropy,
gpt2_sequence_parallel_forward_fn,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [ __all__ = [
@ -71,18 +64,10 @@ class GPT2Policy(Policy):
warnings.warn( warnings.warn(
f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather" f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
) )
sp_mode = "split_gather" self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode in ["split_gather", "ring"] sp_partial_derived = sp_mode in ["split_gather", "ring"]
use_flash_attention = self.shard_config.enable_flash_attention use_flash_attention = self.shard_config.enable_flash_attention
# todo: currently sp cannot be used with flashattention
if sp_mode in ["split_gather", "ring", "all_to_all"]:
if use_flash_attention:
warnings.warn(
f"Sequence parallelism mode {sp_mode} cannot be used with FlashAttention, will disable FlashAttention automatically."
)
self.shard_config.enable_flash_attention = False
use_flash_attention = False
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
assert ( assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@ -211,19 +196,17 @@ class GPT2Policy(Policy):
if use_flash_attention: if use_flash_attention:
self.append_or_create_method_replacement( self.append_or_create_method_replacement(
description={ description={
"forward": get_gpt2_flash_attention_forward(), "forward": get_gpt2_flash_attention_forward(shard_config=self.shard_config),
}, },
policy=policy, policy=policy,
target_key=attn_cls, target_key=attn_cls,
) )
if not self.shard_config.pipeline_stage_manager:
if not self.shard_config.pipeline_stage_manager and self.shard_config.enable_sequence_parallelism:
policy[GPT2Model].method_replacement = { policy[GPT2Model].method_replacement = {
"forward": get_gpt_model_forward_for_flash_attn(self.shard_config) "forward": partial(GPT2PipelineForwards.gpt2_model_forward, shard_config=self.shard_config)
} }
if sp_mode is not None:
policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
return policy return policy
def postprocess(self): def postprocess(self):
@ -328,40 +311,39 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
module_policy = super().module_policy() module_policy = super().module_policy()
module_policy[GPT2LMHeadModel] = ModulePolicyDescription()
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
addon_module = { self.append_or_create_submodule_replacement(
GPT2LMHeadModel: ModulePolicyDescription( description=SubModuleReplacementDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", suffix="lm_head",
target_module=col_nn.VocabParallelLMHead1D, target_module=col_nn.VocabParallelLMHead1D,
kwargs={ kwargs={
"gather_output": False, "gather_output": False,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}, },
),
policy=module_policy,
target_key=GPT2LMHeadModel,
) )
],
)
}
if self.shard_config.parallel_output:
addon_module[GPT2LMHeadModel].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
else: else:
addon_module = { self.append_or_create_submodule_replacement(
GPT2LMHeadModel: ModulePolicyDescription( description=SubModuleReplacementDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", suffix="lm_head",
target_module=col_nn.PaddingLMHead, target_module=col_nn.PaddingLMHead,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=module_policy,
target_key=GPT2LMHeadModel,
) )
]
if self.shard_config.parallel_output:
self.append_or_create_method_replacement(
description={
"forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config)
},
policy=module_policy,
target_key=GPT2LMHeadModel,
) )
}
module_policy.update(addon_module)
if self.pipeline_stage_manager is not None: if self.pipeline_stage_manager is not None:
self.set_pipeline_forward( self.set_pipeline_forward(

@ -16,12 +16,7 @@ from colossalai.shardformer.layer import (
VocabParallelLMHead1D, VocabParallelLMHead1D,
) )
from ..modeling.llama import ( from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
LlamaPipelineForwards,
get_llama_flash_attention_forward,
get_llama_flash_attention_model_forward,
get_lm_forward_with_dist_cross_entropy,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"] __all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"]
@ -99,11 +94,9 @@ class LlamaPolicy(Policy):
if self.pipeline_stage_manager is None: if self.pipeline_stage_manager is None:
self.append_or_create_method_replacement( self.append_or_create_method_replacement(
description={ description={
"forward": get_llama_flash_attention_model_forward( "forward": partial(
self.shard_config, LlamaPipelineForwards.llama_model_forward,
sp_mode=sp_mode, shard_config=self.shard_config,
sp_size=sp_size,
sp_group=sp_group,
), ),
}, },
policy=policy, policy=policy,
@ -351,7 +344,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
elif self.shard_config.enable_tensor_parallelism or self.shard_config.enable_sequence_parallelism: elif self.shard_config.enable_tensor_parallelism or self.shard_config.enable_sequence_parallelism:
# Compute loss distributedly along the sequence dimension # Compute loss distributedly along the sequence dimension
new_item[LlamaForCausalLM].method_replacement = { new_item[LlamaForCausalLM].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) # "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
"forward": partial(LlamaPipelineForwards.llama_for_causal_lm_forward, shard_config=self.shard_config)
} }
return policy return policy

@ -28,7 +28,7 @@ MODEL_CONFIGS = {
"118M": GPT2Config(activation_function="gelu"), "118M": GPT2Config(activation_function="gelu"),
"338M": GPT2Config(n_embd=1024, n_head=16, n_layer=24, activation_function="gelu"), "338M": GPT2Config(n_embd=1024, n_head=16, n_layer=24, activation_function="gelu"),
"738M": GPT2Config(n_embd=1280, n_head=20, n_layer=36, activation_function="gelu"), "738M": GPT2Config(n_embd=1280, n_head=20, n_layer=36, activation_function="gelu"),
"6.21B": GPT2Config(n_embd=4096, n_head=32, n_layer=32, n_positions=4096, activation_function="gelu"), "6.21B": GPT2Config(n_embd=4096, n_head=32, n_layer=32, n_positions=32768, activation_function="gelu"),
} }
@ -60,6 +60,8 @@ def main():
parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini") parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size")
parser.add_argument("--sp_mode", type=str, default="ring_attn", help="Sequence parallel mode")
parser.add_argument("--mbs", type=int, default=1) parser.add_argument("--mbs", type=int, default=1)
parser.add_argument("--zero", type=int, default=0) parser.add_argument("--zero", type=int, default=0)
parser.add_argument("--pp_style", type=str, default="1f1b") parser.add_argument("--pp_style", type=str, default="1f1b")
@ -129,6 +131,9 @@ def main():
tp_size=args.tp, tp_size=args.tp,
pp_size=args.pp, pp_size=args.pp,
pp_style=args.pp_style, pp_style=args.pp_style,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
enable_sequence_parallelism=True,
zero_stage=args.zero, zero_stage=args.zero,
num_model_chunks=args.num_model_chunks, num_model_chunks=args.num_model_chunks,
enable_all_optimization=True, enable_all_optimization=True,
@ -214,6 +219,8 @@ def main():
performance_evaluator.on_step_start(step) performance_evaluator.on_step_start(step)
outputs = model(**batch) outputs = model(**batch)
loss = outputs[0] loss = outputs[0]
del outputs
booster.backward(loss, optimizer) booster.backward(loss, optimizer)
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()

@ -6,7 +6,6 @@ import torch.distributed as dist
from torch import Tensor from torch import Tensor
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
from colossalai.accelerator import get_accelerator
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
@ -22,8 +21,11 @@ def divide(x: float, y: float) -> float:
def all_reduce_mean(x: float, world_size: int) -> float: def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1: if world_size == 1:
return x return x
tensor = torch.tensor([x], device=get_accelerator().get_current_device())
dist.all_reduce(tensor) # Use CPU tensor to avoid OOM/weird NCCl error
gloo_group = dist.new_group(backend="gloo")
tensor = torch.tensor([x], device="cpu")
dist.all_reduce(tensor, group=gloo_group)
tensor = tensor / world_size tensor = tensor / world_size
return tensor.item() return tensor.item()

@ -27,7 +27,16 @@ def data_gen_for_lm():
# LM data gen # LM data gen
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
data = data_gen() data = data_gen()
data["labels"] = data["input_ids"].clone()
# Test padded sequence for Ring Attention
padding = torch.zeros(1, data["input_ids"].shape[1] // 2, dtype=torch.long)
data["input_ids"] = torch.cat([data["input_ids"], padding], dim=1)
data["attention_mask"] = torch.cat([data["attention_mask"], padding], dim=1)
ignore_idx = -100
labels = data["input_ids"].clone()
labels[~data["attention_mask"].bool()] = ignore_idx
data["labels"] = labels
return data return data

@ -157,7 +157,6 @@ def build_model_from_hybrid_plugin(
sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3) sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
criterion = loss_fn criterion = loss_fn
plugin = pluggin_cls(**test_config) plugin = pluggin_cls(**test_config)
booster = Booster(plugin=plugin) booster = Booster(plugin=plugin)
@ -323,7 +322,6 @@ def check_output_hidden_state(
sp_size = shard_config.sequence_parallel_size sp_size = shard_config.sequence_parallel_size
if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size: if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size:
org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)] org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)]
assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol) assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol)

@ -136,26 +136,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize( @parameterize(
"test_config", "test_config",
[ [
{ # Ulysess + Flash attention {
"tp_size": 1, "tp_size": 2,
"pp_size": 2, "pp_size": 2,
"sp_size": 2,
"num_microbatches": 2, "num_microbatches": 2,
"enable_sequence_parallelism": True, "enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all", "sequence_parallelism_mode": "split_gather",
"enable_flash_attention": True, "enable_flash_attention": True,
"use_lazy_init": True, "use_lazy_init": True,
"zero_stage": 1, "zero_stage": 1,
"precision": "fp16", "precision": "fp16",
"initial_scale": 1, "initial_scale": 1,
}, },
{ { # Ulysess + Flash attention
"tp_size": 2, "tp_size": 1,
"pp_size": 2, "pp_size": 2,
"sp_size": 2, "sp_size": 2,
"num_microbatches": 2, "num_microbatches": 2,
"enable_sequence_parallelism": True, "enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather", "sequence_parallelism_mode": "all_to_all",
"enable_flash_attention": True, "enable_flash_attention": True,
"use_lazy_init": True, "use_lazy_init": True,
"zero_stage": 1, "zero_stage": 1,
@ -174,17 +173,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"precision": "fp16", "precision": "fp16",
"initial_scale": 1, "initial_scale": 1,
}, },
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp32",
"initial_scale": 1,
},
{ {
"tp_size": 4, "tp_size": 4,
"pp_size": 1, "pp_size": 1,
@ -248,7 +236,11 @@ def run_chatglm_test(test_config):
loss_fn, loss_fn,
_, _,
) in sub_model_zoo.items(): ) in sub_model_zoo.items():
try:
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
except Exception as e:
print(f"Test config failed for model {name}: {test_config}")
raise e
clear_layout_converter() clear_layout_converter()
torch.cuda.empty_cache() torch.cuda.empty_cache()

@ -125,7 +125,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 5e-3 atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == "CohereModel": if org_model.__class__.__name__ == "CohereModel":
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_output_hidden_state(
org_output,
sharded_output,
stage_manager,
atol=atol,
rtol=rtol,
shard_config=booster.plugin.shard_config,
)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
@ -274,7 +281,11 @@ def run_command_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm") sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
try:
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
except Exception as e:
print(f"Failed test config: {test_config}")
raise e
clear_layout_converter() clear_layout_converter()
Randomizer.reset_index() Randomizer.reset_index()

@ -100,7 +100,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 5e-3 atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == "GPT2Model": if org_model.__class__.__name__ == "GPT2Model":
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_output_hidden_state(
org_output,
sharded_output,
stage_manager,
atol=atol,
rtol=rtol,
shard_config=booster.plugin.shard_config,
)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
@ -132,14 +139,27 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"test_config", "test_config",
[ [
{ {
"tp_size": 4, "sp_size": 2,
"tp_size": 1,
"pp_size": 2,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring_attn",
"num_microbatches": 2,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"sp_size": 2,
"tp_size": 2,
"pp_size": 1, "pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True, "enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring", "sequence_parallelism_mode": "ring_attn",
"enable_flash_attention": False, "num_microbatches": 1,
"enable_all_optimization": True,
"use_lazy_init": True, "use_lazy_init": True,
"precision": "fp32", "precision": "fp16",
"initial_scale": 1, "initial_scale": 1,
}, },
{ {
@ -148,7 +168,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"num_microbatches": 1, "num_microbatches": 1,
"enable_sequence_parallelism": True, "enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather", "sequence_parallelism_mode": "split_gather",
"enable_flash_attention": False, "enable_flash_attention": True,
"use_lazy_init": True, "use_lazy_init": True,
"precision": "fp16", "precision": "fp16",
"initial_scale": 1, "initial_scale": 1,
@ -156,7 +176,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
{ {
"tp_size": 2, "tp_size": 2,
"pp_size": 2, "pp_size": 2,
"num_microbatches": 4, "num_microbatches": 2,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": True,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 2,
"enable_all_optimization": True, "enable_all_optimization": True,
"use_lazy_init": True, "use_lazy_init": True,
"precision": "fp16", "precision": "fp16",
@ -185,7 +216,16 @@ def run_gpt2_test(test_config):
loss_fn, loss_fn,
_, _,
) in sub_model_zoo.items(): ) in sub_model_zoo.items():
if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and name != "transformers_gpt_lm":
# Only wrote zigzag splitting for cross entropy loss
continue
try:
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
except Exception as e:
print(f"Failed config: {test_config} for model {name}")
raise (e)
clear_layout_converter() clear_layout_converter()
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -226,7 +266,11 @@ def run_gpt2_3d_test(test_config):
loss_fn, loss_fn,
_, _,
) in sub_model_zoo.items(): ) in sub_model_zoo.items():
try:
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
except Exception as e:
print(f"Failed config: {test_config} for model {name}")
raise (e)
clear_layout_converter() clear_layout_converter()
torch.cuda.empty_cache() torch.cuda.empty_cache()

@ -165,7 +165,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"zero_stage": 0, "zero_stage": 0,
"precision": "fp16", "precision": "fp16",
"initial_scale": 1, "initial_scale": 1,
"inner_ring_size": 2,
}, },
# Ring Attention + PP # Ring Attention + PP
{ {
@ -215,18 +214,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"sequence_parallelism_mode": "all_to_all", "sequence_parallelism_mode": "all_to_all",
"enable_all_optimization": True, "enable_all_optimization": True,
"use_lazy_init": True, "use_lazy_init": True,
"zero_stage": 0, "zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": True,
"use_lazy_init": True,
"precision": "fp16", "precision": "fp16",
"initial_scale": 1, "initial_scale": 1,
}, },
@ -294,6 +282,7 @@ def run_llama_test(test_config):
except Exception as e: except Exception as e:
print(f"Failed config: {test_config}, model name: {name}") print(f"Failed config: {test_config}, model name: {name}")
raise e raise e
clear_layout_converter() clear_layout_converter()
Randomizer.reset_index() Randomizer.reset_index()
torch.cuda.empty_cache() torch.cuda.empty_cache()

@ -94,6 +94,32 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize( @parameterize(
"test_config", "test_config",
[ [
{
"tp_size": 2,
"pp_size": 2,
"sp_size": 2,
"num_microbatches": 2,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
{ # Ulysess + Flash attention
"tp_size": 1,
"pp_size": 2,
"sp_size": 2,
"num_microbatches": 2,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
{ {
"tp_size": 2, "tp_size": 2,
"pp_size": 2, "pp_size": 2,
@ -135,32 +161,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"precision": "fp16", "precision": "fp16",
"initial_scale": 1, "initial_scale": 1,
}, },
{ # Ulysess + Flash attention
"tp_size": 1,
"pp_size": 2,
"sp_size": 2,
"num_microbatches": 2,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"sp_size": 2,
"num_microbatches": 2,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
{ {
"tp_size": 2, "tp_size": 2,
"pp_size": 2, "pp_size": 2,

Loading…
Cancel
Save