mirror of https://github.com/hpcaitech/ColossalAI
[Feature] Split cross-entropy computation in SP (#5959)
* halfway * fix cross-PP-stage position id length diff bug * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unified cross entropy func for all shardformer models * remove redundant lines * add basic ring attn; debug cross entropy * fwd bwd logic complete * fwd bwd logic complete; add experimental triton rescale * precision tests passed * precision tests passed * fix typos and remove misc files * update softmax_lse shape by new interface * change tester name * remove buffer clone; support packed seq layout * add varlen tests * fix typo * all tests passed * add dkv_group; fix mask * remove debug statements * adapt chatglm, command-R, qwen * debug * halfway * fix cross-PP-stage position id length diff bug * fix typo * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unified cross entropy func for all shardformer models * remove redundant lines * add basic ring attn; debug cross entropy * fwd bwd logic complete * fwd bwd logic complete; add experimental triton rescale * precision tests passed * precision tests passed * fix typos and remove misc files * add sp_mode to benchmark; fix varlen interface * update softmax_lse shape by new interface * add varlen tests * fix typo * all tests passed * add dkv_group; fix mask * remove debug statements * add comments * q1 index only once * remove events to simplify stream sync * simplify forward/backward logic * 2d ring forward passed * 2d ring backward passed * fixes * fix ring attn loss * 2D ring backward + llama passed * merge * update logger * fix typo * rebase * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo * remove typos * fixes * support GPT --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/6053/head
parent
b3db1058ec
commit
8fd25d6e09
|
@ -1097,13 +1097,19 @@ def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8
|
|||
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
|
||||
|
||||
|
||||
def gather_sp_output(hidden_states, sp_group, sp_mode, fp8_communication=False):
|
||||
def gather_sp_output(hidden_states, shard_config, sp_dim=1):
|
||||
"""
|
||||
Gather the output of the last layer for cross entropy computation
|
||||
"""
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
fp8_comm = shard_config.fp8_communication
|
||||
if dist.get_world_size(sp_group) == 1:
|
||||
return hidden_states
|
||||
|
||||
# Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
|
||||
scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, 1, sp_group, grad_scale=scale, fp8_communication=fp8_communication
|
||||
hidden_states, sp_dim, sp_group, grad_scale=scale, fp8_communication=fp8_comm
|
||||
)
|
||||
return hidden_states
|
||||
|
|
|
@ -433,7 +433,6 @@ class RingAttention(torch.autograd.Function):
|
|||
assert (
|
||||
sp_size % inner_ring_size == 0
|
||||
), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
|
||||
|
||||
logger = get_dist_logger()
|
||||
logger.info(
|
||||
f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!",
|
||||
|
@ -898,6 +897,7 @@ class RingAttention(torch.autograd.Function):
|
|||
|
||||
local_sp_rank = dist.get_rank(sp_group)
|
||||
sp_size = dist.get_world_size(sp_group)
|
||||
|
||||
# Using separate streams (pg) for concurrent kv and dkv comm may
|
||||
# cause NCCL "software caused connection abort" here...
|
||||
local_kv_comm = RingComm(local_kv_group)
|
||||
|
@ -1119,9 +1119,14 @@ class RingAttention(torch.autograd.Function):
|
|||
the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.
|
||||
|
||||
Returns:
|
||||
inputs_embeds: Packed input embeddings of shape [B, Sq // sp_size, ...].
|
||||
mask_info: A dictionary of mask info.
|
||||
position_ids: Packed position ids of shape [..., Sq // sp_size].
|
||||
torch.Tensor:
|
||||
Packed input embeddings of shape [B, Sq // sp_size, ...].
|
||||
|
||||
Dict[str, Any]:
|
||||
A dictionary containing mask info.
|
||||
|
||||
torch.Tensor:
|
||||
Packed position ids of shape [..., Sq // sp_size].
|
||||
|
||||
"""
|
||||
_load_varlen_helpers()
|
||||
|
|
|
@ -153,7 +153,6 @@ def dist_cross_entropy(
|
|||
labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
|
||||
logits: torch.Tensor, # [B, S, Vocab_size]
|
||||
shard_config: ShardConfig,
|
||||
out_features: int,
|
||||
vocab_size: int,
|
||||
dtype: torch.dtype,
|
||||
seq_dim: int = 1,
|
||||
|
@ -226,13 +225,13 @@ def dist_cross_entropy(
|
|||
logits,
|
||||
labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=out_features,
|
||||
vocab_size=vocab_size,
|
||||
dtype=dtype,
|
||||
mode="sum",
|
||||
)
|
||||
else:
|
||||
# NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D
|
||||
logits = logits.view(-1, vocab_size)
|
||||
logits = logits.view(-1, logits.size(-1))
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
# Reduce loss instead of gathering logits over seq dim for savings
|
||||
|
|
|
@ -313,19 +313,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
|
||||
# Matrix multiply.
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
if self.seq_parallel_mode is None:
|
||||
# Set up backprop all-reduce.
|
||||
input_parallel = reduce_backward(input_, self.process_group, fp8_communication=self.fp8_communication)
|
||||
output_parallel = matmul_with_async_comm(
|
||||
input_parallel,
|
||||
self.weight,
|
||||
bias,
|
||||
self.process_group,
|
||||
self.async_communication,
|
||||
fp8_communication=self.fp8_communication,
|
||||
)
|
||||
elif self.seq_parallel_mode == "split_gather":
|
||||
if self.seq_parallel_mode == "split_gather":
|
||||
input_parallel = input_
|
||||
output_parallel = matmul_gather_forward_reducescatter_backward(
|
||||
input_parallel,
|
||||
|
@ -340,8 +328,29 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
elif self.seq_parallel_mode == "ring":
|
||||
input_parallel = input_
|
||||
output_parallel = matmul_gather_forward_reducescatter_backward(
|
||||
input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap, True
|
||||
input_parallel,
|
||||
self.weight,
|
||||
bias,
|
||||
self.process_group,
|
||||
True,
|
||||
1,
|
||||
self.overlap,
|
||||
True,
|
||||
fp8_communication=self.fp8_communication,
|
||||
)
|
||||
elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
|
||||
# Set up backprop all-reduce.
|
||||
input_parallel = reduce_backward(input_, self.process_group)
|
||||
output_parallel = matmul_with_async_comm(
|
||||
input_parallel,
|
||||
self.weight,
|
||||
bias,
|
||||
self.process_group,
|
||||
self.async_communication,
|
||||
fp8_communication=self.fp8_communication,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
|
||||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
|
@ -553,7 +562,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
|||
handle.wait()
|
||||
output = torch.cat(output_parallel_list, dim=-1)
|
||||
else:
|
||||
if self.seq_parallel_mode is None:
|
||||
if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
|
||||
output_parallel = torch.matmul(input_, self.weight)
|
||||
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
|
||||
elif self.seq_parallel_mode == "split_gather":
|
||||
|
@ -567,8 +576,12 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
|||
elif self.seq_parallel_mode == "ring":
|
||||
output_parallel = torch.matmul(input_, self.weight)
|
||||
output = reducescatter_forward_gather_backward(
|
||||
output_parallel, self.process_group, 1, self.fp8_communication
|
||||
output_parallel,
|
||||
self.process_group,
|
||||
1,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
|
||||
|
||||
if not self.skip_bias_add:
|
||||
if self.bias is not None:
|
||||
|
|
|
@ -309,6 +309,9 @@ def split_batch_zigzag(
|
|||
"""
|
||||
sp_size = dist.get_world_size(sp_group)
|
||||
sp_rank = dist.get_rank(sp_group)
|
||||
if sp_size == 1:
|
||||
return batch
|
||||
|
||||
if isinstance(batch, torch.Tensor):
|
||||
batch = [batch]
|
||||
seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1
|
||||
|
@ -364,6 +367,9 @@ def split_varlen_zigzag(
|
|||
"""
|
||||
sp_size = dist.get_world_size(sp_group)
|
||||
sp_rank = dist.get_rank(sp_group)
|
||||
if sp_size == 1:
|
||||
return batch
|
||||
|
||||
if is_2d:
|
||||
assert max_seqlen > 0, "max_seqlen must be provided for 2D input"
|
||||
|
||||
|
|
|
@ -365,14 +365,15 @@ class BloomPipelineForwards:
|
|||
hidden_states = transformer_outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states).contiguous()
|
||||
|
||||
loss = dist_cross_entropy(
|
||||
labels,
|
||||
lm_logits,
|
||||
shard_config,
|
||||
self.lm_head.out_features,
|
||||
self.config.vocab_size,
|
||||
self.transformer.dtype,
|
||||
)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(
|
||||
labels,
|
||||
lm_logits,
|
||||
shard_config,
|
||||
self.lm_head.out_features,
|
||||
self.transformer.dtype,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
|
@ -1036,9 +1037,11 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
hidden_states = transformer_outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = dist_cross_entropy(
|
||||
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
|
||||
)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(
|
||||
labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
|
|
|
@ -4,7 +4,6 @@ from typing import List, Optional, Tuple
|
|||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.utils import logging
|
||||
|
||||
|
@ -13,10 +12,13 @@ from colossalai.shardformer import ShardConfig
|
|||
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
|
||||
from colossalai.shardformer.layer._operation import (
|
||||
all_to_all_comm,
|
||||
gather_forward_split_backward,
|
||||
gather_sp_output,
|
||||
is_share_sp_tp,
|
||||
split_forward_gather_backward,
|
||||
)
|
||||
|
||||
from ..layer import dist_cross_entropy
|
||||
|
||||
|
||||
def get_flash_core_attention_forward():
|
||||
from .chatglm2_6b.modeling_chatglm import CoreAttention
|
||||
|
@ -138,6 +140,7 @@ class ChatGLMPipelineForwards:
|
|||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
force_sp_output_gather: Optional[bool] = True,
|
||||
):
|
||||
logger = logging.get_logger(__name__)
|
||||
output_hidden_states = (
|
||||
|
@ -180,6 +183,15 @@ class ChatGLMPipelineForwards:
|
|||
if full_attention_mask is None:
|
||||
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
||||
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
|
||||
|
||||
# Support SP + PP
|
||||
sp_size = shard_config.sequence_parallel_size
|
||||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
# For generating full positions ids (the states will be gathered along the seq dim before attention fwd).
|
||||
if sp_mode != "ring_attn" and not stage_manager.is_first_stage():
|
||||
seq_length *= sp_size
|
||||
|
||||
# Rotary positional embeddings
|
||||
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
|
||||
if position_ids is not None:
|
||||
|
@ -200,29 +212,23 @@ class ChatGLMPipelineForwards:
|
|||
all_hidden_states = () if output_hidden_states else None
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
|
||||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=1 / shard_config.sequence_parallel_size,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=1 / shard_config.sequence_parallel_size,
|
||||
)
|
||||
# Keep the input split across all PP stages
|
||||
if stage_manager.is_first_stage():
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
if sp_mode == "split_gather":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=sp_group,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=1 / shard_config.sequence_parallel_size,
|
||||
)
|
||||
|
||||
for idx in range(start_idx, end_idx):
|
||||
layer = self.encoder._get_layer(idx)
|
||||
if output_hidden_states:
|
||||
|
@ -248,35 +254,19 @@ class ChatGLMPipelineForwards:
|
|||
if use_cache:
|
||||
presents = presents + (kv_cache,)
|
||||
|
||||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=shard_config.sequence_parallel_size,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=shard_config.sequence_parallel_size,
|
||||
)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
if stage_manager.is_last_stage():
|
||||
# final layer_norm
|
||||
if self.encoder.post_layer_norm:
|
||||
hidden_states = self.encoder.final_layernorm(hidden_states)
|
||||
|
||||
# Gather seq-wise in the final output stage
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
|
||||
hidden_states = gather_sp_output(hidden_states, shard_config, sp_dim=0)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
|
@ -333,6 +323,7 @@ class ChatGLMPipelineForwards:
|
|||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
force_sp_output_gather=False,
|
||||
)
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
@ -340,17 +331,21 @@ class ChatGLMPipelineForwards:
|
|||
hidden_states = hidden_states[-1:]
|
||||
lm_logits = self.transformer.output_layer(hidden_states)
|
||||
lm_logits = lm_logits.transpose(0, 1).contiguous()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
lm_logits = lm_logits.to(torch.float32)
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
lm_logits = lm_logits.to(hidden_states.dtype)
|
||||
loss = loss.to(hidden_states.dtype)
|
||||
# ChatGLM doesn't have lm_head split
|
||||
enable_tp = shard_config.enable_tensor_parallelism
|
||||
shard_config.enable_tensor_parallelism = False
|
||||
loss = dist_cross_entropy(
|
||||
labels,
|
||||
lm_logits,
|
||||
shard_config,
|
||||
self.transformer.output_layer.out_features,
|
||||
lm_logits.dtype,
|
||||
)
|
||||
shard_config.enable_tensor_parallelism = enable_tp
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
@ -379,6 +374,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
|
|||
use_cache: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
force_sp_output_gather: Optional[bool] = True,
|
||||
):
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
|
@ -456,22 +452,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
|
|||
use_cache=use_cache,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
if sp_mode in ["split_gather"]:
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=sp_group,
|
||||
grad_scale=sp_size,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
|
||||
hidden_states = gather_sp_output(hidden_states, shard_config, sp_dim=0)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
|
|
|
@ -17,14 +17,13 @@ from transformers.models.cohere.modeling_cohere import (
|
|||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer._operation import (
|
||||
all_to_all_comm,
|
||||
gather_forward_split_backward,
|
||||
split_forward_gather_backward,
|
||||
)
|
||||
from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import ColoAttention, dist_cross_entropy
|
||||
from ..layer._operation import gather_sp_output, is_share_sp_tp
|
||||
|
||||
_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring"]
|
||||
|
||||
_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
|
||||
|
||||
|
@ -52,6 +51,7 @@ class CommandPipelineForwards:
|
|||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
force_sp_output_gather: bool = True,
|
||||
):
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
@ -93,10 +93,16 @@ class CommandPipelineForwards:
|
|||
if not isinstance(past_key_values, StaticCache):
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_seen_tokens = past_key_values.get_seq_length()
|
||||
|
||||
# NOTE: For generating full positions ids
|
||||
# (the states will be gathered along the seq dim before attention fwd).
|
||||
if shard_config.sequence_parallelism_mode != "ring_attn" and not stage_manager.is_first_stage():
|
||||
seq_length *= shard_config.sequence_parallel_size
|
||||
|
||||
if cache_position is None:
|
||||
if isinstance(past_key_values, StaticCache):
|
||||
raise ValueError("cache_position is a required argument when using StaticCache.")
|
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)
|
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
|
||||
|
||||
seq_length_with_past = seq_length + past_seen_tokens
|
||||
|
||||
|
@ -136,7 +142,7 @@ class CommandPipelineForwards:
|
|||
)
|
||||
use_cache = False
|
||||
|
||||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if stage_manager.is_first_stage() and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
|
@ -208,23 +214,10 @@ class CommandPipelineForwards:
|
|||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=shard_config.sequence_parallel_size,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
|
||||
hidden_states = gather_sp_output(hidden_states, shard_config)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
|
@ -327,6 +320,7 @@ class CommandPipelineForwards:
|
|||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
force_sp_output_gather=False,
|
||||
)
|
||||
past_key_values = None
|
||||
|
||||
|
@ -335,9 +329,10 @@ class CommandPipelineForwards:
|
|||
logits = self.lm_head(hidden_states)
|
||||
logits = logits * self.logit_scale
|
||||
logits = logits.float()
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
|
||||
)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -482,6 +477,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
|
|||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
force_sp_output_gather: bool = True,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
|
@ -584,14 +580,10 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
|
|||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if sp_mode == "ring" or sp_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
# Cases that don't support parallelizing cross entropy computation along sequence
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
|
||||
hidden_states = gather_sp_output(hidden_states, shard_config)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
|
@ -676,6 +668,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
force_sp_output_gather=False,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
@ -683,14 +676,16 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
logits = self.lm_head(hidden_states)
|
||||
logits = logits * self.logit_scale
|
||||
logits = logits.float()
|
||||
loss = dist_cross_entropy(
|
||||
labels,
|
||||
logits,
|
||||
shard_config,
|
||||
self.lm_head.out_features,
|
||||
self.config.vocab_size,
|
||||
self.model.dtype,
|
||||
)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(
|
||||
labels,
|
||||
logits,
|
||||
shard_config,
|
||||
self.lm_head.out_features,
|
||||
self.model.dtype,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
|
|
@ -21,8 +21,9 @@ from transformers.models.gpt2.modeling_gpt2 import (
|
|||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer import ColoAttention
|
||||
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
||||
from colossalai.shardformer.layer import ColoAttention, RingAttention
|
||||
from colossalai.shardformer.layer._operation import gather_sp_output, split_forward_gather_backward
|
||||
from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import dist_cross_entropy
|
||||
|
@ -39,10 +40,16 @@ def _get_attention_mask(
|
|||
encoder_hidden_states: Optional[torch.Tensor],
|
||||
encoder_attention_mask: Optional[torch.FloatTensor],
|
||||
) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]:
|
||||
batch_size, seq_len = hidden_states.shape[:2]
|
||||
# Received input is already split for non-first pipeline stages,
|
||||
# but attn mask isn't
|
||||
batch_size = hidden_states.size(0)
|
||||
seq_len = attention_mask.size(-1)
|
||||
|
||||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||
assert not sp_mode == "ring_attn", "Ring Attention only supports decoder-only."
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
if shard_config.enable_flash_attention:
|
||||
encoder_attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
|
@ -62,6 +69,7 @@ def _get_attention_mask(
|
|||
encoder_attention_mask = {"attention_mask": None}
|
||||
else:
|
||||
encoder_attention_mask = None
|
||||
|
||||
# GPT2Attention mask.
|
||||
past_key_values_length = 0
|
||||
if past_key_values is not None and past_key_values[0] is not None:
|
||||
|
@ -69,6 +77,7 @@ def _get_attention_mask(
|
|||
if shard_config.enable_flash_attention:
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
(batch_size, 1, seq_len, seq_len + past_key_values_length),
|
||||
hidden_states.dtype,
|
||||
|
@ -123,6 +132,7 @@ class GPT2PipelineForwards:
|
|||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
force_sp_gather: Optional[bool] = True,
|
||||
) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
|
||||
# Please refer to original code of transformers for more details.
|
||||
|
@ -146,16 +156,15 @@ class GPT2PipelineForwards:
|
|||
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||
use_cache = False
|
||||
|
||||
if stage_manager.is_first_stage():
|
||||
disable_pp = stage_manager is None
|
||||
if disable_pp or stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
inputs_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
|
@ -176,7 +185,7 @@ class GPT2PipelineForwards:
|
|||
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
if stage_manager.is_first_stage():
|
||||
if disable_pp or stage_manager.is_first_stage():
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
|
@ -190,9 +199,7 @@ class GPT2PipelineForwards:
|
|||
hidden_states = hidden_states + token_type_embeds
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
attention_mask, encoder_attention_mask = _get_attention_mask(
|
||||
attn_kwargs, encoder_attention_mask = _get_attention_mask(
|
||||
self,
|
||||
shard_config,
|
||||
hidden_states,
|
||||
|
@ -215,23 +222,43 @@ class GPT2PipelineForwards:
|
|||
|
||||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
if disable_pp or stage_manager.is_first_stage():
|
||||
# Ring Attention's special zigzag batch processing
|
||||
if sp_mode == "ring_attn":
|
||||
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
|
||||
if not attention_mask.bool().all():
|
||||
hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
|
||||
attention_mask, sp_group, hidden_states, position_ids
|
||||
)
|
||||
else:
|
||||
hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
|
||||
# Other sp modes
|
||||
else:
|
||||
if sp_mode == "split_gather":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
elif sp_mode == "ring_attn":
|
||||
# Later stages already received split hidden states
|
||||
_, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)
|
||||
del attention_mask
|
||||
|
||||
# Going through held blocks.
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
if disable_pp:
|
||||
start_idx, end_idx = 0, len(self.h)
|
||||
else:
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
|
||||
for i in range(start_idx, end_idx):
|
||||
block = self.h[i]
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if torch.is_tensor(attention_mask):
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
if torch.is_tensor(attn_kwargs):
|
||||
attn_kwargs = attn_kwargs.to(hidden_states.device)
|
||||
if isinstance(head_mask, torch.Tensor):
|
||||
head_mask = head_mask.to(hidden_states.device)
|
||||
if output_hidden_states:
|
||||
|
@ -242,7 +269,7 @@ class GPT2PipelineForwards:
|
|||
block.__call__,
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
attn_kwargs,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
|
@ -253,7 +280,7 @@ class GPT2PipelineForwards:
|
|||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=None,
|
||||
attention_mask=attention_mask,
|
||||
attention_mask=attn_kwargs,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
|
@ -270,26 +297,25 @@ class GPT2PipelineForwards:
|
|||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||
|
||||
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
||||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
# When sequence parallelism is done, gather the output tensor in forward and split it in backward
|
||||
gather_output = (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode)
|
||||
if disable_pp or stage_manager.is_last_stage():
|
||||
if gather_output:
|
||||
hidden_states = gather_sp_output(hidden_states, shard_config)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
# gather_sp_output could've changed seq length.
|
||||
input_shape = (*input_shape[:-1], hidden_states.size(-2))
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
if disable_pp or stage_manager.is_last_stage():
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
|
||||
# Add last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
if disable_pp or stage_manager.is_last_stage():
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
|
@ -366,17 +392,29 @@ class GPT2PipelineForwards:
|
|||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
force_sp_gather=False,
|
||||
)
|
||||
|
||||
# If not at the last stage, return hidden_states as in GPT2Model
|
||||
if not stage_manager.is_last_stage():
|
||||
disable_pp = stage_manager is None
|
||||
if (not disable_pp) and (not stage_manager.is_last_stage()):
|
||||
return {"hidden_states": outputs["hidden_states"]}
|
||||
|
||||
hidden_states = outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
loss = dist_cross_entropy(
|
||||
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
|
||||
)
|
||||
if shard_config.sequence_parallelism_mode == "ring_attn":
|
||||
# Split labels in a zigzag fashion too
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
if not attention_mask.bool().all():
|
||||
# [B, max_seqlen // sp_size]
|
||||
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
|
||||
else:
|
||||
labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
|
||||
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(
|
||||
labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + outputs[1:]
|
||||
|
@ -770,7 +808,7 @@ class GPT2PipelineForwards:
|
|||
)
|
||||
|
||||
|
||||
def get_gpt2_flash_attention_forward():
|
||||
def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None):
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
|
||||
|
||||
def forward(
|
||||
|
@ -817,7 +855,22 @@ def get_gpt2_flash_attention_forward():
|
|||
if self.scale_attn_by_inverse_layer_idx:
|
||||
scale /= float(self.layer_idx + 1)
|
||||
dropout_p = self.attn_dropout.p if self.training else 0.0
|
||||
attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
|
||||
|
||||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
if sp_mode == "ring_attn":
|
||||
attn_output = RingAttention.attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
sp_group,
|
||||
**attention_mask,
|
||||
dropout_p=dropout_p,
|
||||
scale=scale,
|
||||
inner_ring_size=shard_config.inner_ring_size,
|
||||
)
|
||||
else:
|
||||
attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
|
||||
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
||||
attn_output = self.c_proj(attn_output)
|
||||
attn_output = self.resid_dropout(attn_output)
|
||||
|
@ -828,466 +881,6 @@ def get_gpt2_flash_attention_forward():
|
|||
return forward
|
||||
|
||||
|
||||
def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig):
|
||||
def forward(
|
||||
self: GPT2Model,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
inputs_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_length,
|
||||
input_shape[-1] + past_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
hidden_states = hidden_states + token_type_embeds
|
||||
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
|
||||
|
||||
attention_mask, encoder_attention_mask = _get_attention_mask(
|
||||
self,
|
||||
shard_config,
|
||||
hidden_states,
|
||||
past_key_values,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
# Model parallel
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||
if layer_past is not None:
|
||||
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if torch.is_tensor(attention_mask):
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
if isinstance(head_mask, torch.Tensor):
|
||||
head_mask = head_mask.to(hidden_states.device)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||
|
||||
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||
if self.model_parallel:
|
||||
for k, v in self.device_map.items():
|
||||
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
||||
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
# Add last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
presents,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
inputs_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_length,
|
||||
input_shape[-1] + past_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
hidden_states = hidden_states + token_type_embeds
|
||||
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
attention_mask, encoder_attention_mask = _get_attention_mask(
|
||||
self,
|
||||
shard_config,
|
||||
hidden_states,
|
||||
past_key_values,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger = logging.get_logger(__name__)
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
# Model parallel
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||
if layer_past is not None:
|
||||
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if torch.is_tensor(attention_mask):
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
if isinstance(head_mask, torch.Tensor):
|
||||
head_mask = head_mask.to(hidden_states.device)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||
|
||||
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||
if self.model_parallel:
|
||||
for k, v in self.device_map.items():
|
||||
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
||||
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
||||
|
||||
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
# Add last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
presents,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||
from transformers import GPT2LMHeadModel
|
||||
|
||||
def forward(
|
||||
self: GPT2LMHeadModel,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
loss = dist_cross_entropy(
|
||||
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=loss,
|
||||
logits=lm_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
cross_attentions=transformer_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_gpt2_mlp_forward():
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
|
||||
|
||||
|
|
|
@ -25,7 +25,6 @@ from transformers.models.llama.modeling_llama import (
|
|||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer import AttnMaskType
|
||||
from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
|
||||
from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
@ -58,10 +57,7 @@ class LlamaPipelineForwards:
|
|||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
# Split output only when computing cross entropy using llama_for_causal_lm_forward
|
||||
# or get_lm_forward_with_dist_cross_entropy
|
||||
# Default to True to avoid bug when calling classification forward from huggingface
|
||||
force_sp_output_gather: bool = True,
|
||||
force_sp_gather: bool = True, # Set to false only when computing cross entropy
|
||||
):
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
@ -78,8 +74,9 @@ class LlamaPipelineForwards:
|
|||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
disable_pp = stage_manager is None
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager.is_first_stage():
|
||||
if disable_pp or stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
|
@ -88,10 +85,10 @@ class LlamaPipelineForwards:
|
|||
batch_size, seq_length, _ = inputs_embeds.shape[:2]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
device = hidden_states.device
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
|
@ -101,8 +98,8 @@ class LlamaPipelineForwards:
|
|||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
sp_size = shard_config.sequence_parallel_size
|
||||
if sp_mode == "all_to_all" and not stage_manager.is_first_stage():
|
||||
# For generating full positions ids, as the states will be gather along the seq dim in the attention layer later.
|
||||
# Generating full positions ids for modes that gather sequence before attn
|
||||
if stage_manager and (sp_mode != "ring_attn" and not stage_manager.is_first_stage()):
|
||||
seq_length *= sp_size
|
||||
|
||||
past_seen_tokens = 0
|
||||
|
@ -117,7 +114,6 @@ class LlamaPipelineForwards:
|
|||
|
||||
seq_length_with_past = seq_length + past_seen_tokens
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
|
@ -130,14 +126,13 @@ class LlamaPipelineForwards:
|
|||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
if not stage_manager.is_first_stage() and sp_mode == "ring_attn":
|
||||
|
||||
no_split_input = disable_pp or not stage_manager.is_first_stage()
|
||||
if no_split_input and sp_mode == "ring_attn":
|
||||
_, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)
|
||||
elif shard_config.enable_flash_attention:
|
||||
# in this case, attention_mask is a dict rather than a tensor
|
||||
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
|
||||
attn_kwargs = ColoAttention.prepare_attn_kwargs(
|
||||
attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
|
@ -146,15 +141,15 @@ class LlamaPipelineForwards:
|
|||
invert=(sp_mode != "ring_attn"),
|
||||
)
|
||||
else:
|
||||
attn_kwargs = self._update_causal_mask(attention_mask, hidden_states, cache_position)
|
||||
attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position)
|
||||
|
||||
# Support SP + PP
|
||||
# TODO: support padded casual cu_seqlens across stages
|
||||
if stage_manager.is_first_stage():
|
||||
# Support SP + PP. Later stages have already received the split input.
|
||||
split_input = disable_pp or stage_manager.is_first_stage()
|
||||
if split_input:
|
||||
# Ring Attention zigzag batch processing
|
||||
if sp_mode == "ring_attn":
|
||||
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
|
||||
if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
|
||||
if not attention_mask.bool().all():
|
||||
hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
|
||||
attention_mask, sp_group, hidden_states, position_ids
|
||||
)
|
||||
|
@ -181,8 +176,8 @@ class LlamaPipelineForwards:
|
|||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
start_idx, end_idx = (0, len(self.layers)) if disable_pp else (stage_index[0], stage_index[1])
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
num_ckpt_layers = 0
|
||||
if self.gradient_checkpointing and self.training:
|
||||
num_ckpt_layers = end_idx - start_idx
|
||||
|
@ -228,18 +223,16 @@ class LlamaPipelineForwards:
|
|||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
if disable_pp or stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
|
||||
hidden_states = gather_sp_output(
|
||||
hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
if (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode): # noqa
|
||||
hidden_states = gather_sp_output(hidden_states, shard_config)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if stage_manager.is_last_stage():
|
||||
if disable_pp or stage_manager.is_last_stage():
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
|
@ -257,7 +250,7 @@ class LlamaPipelineForwards:
|
|||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
# always return dict for imediate stage
|
||||
# always return dict for intermediate stage
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
|
@ -323,7 +316,7 @@ class LlamaPipelineForwards:
|
|||
# Split labels in a zigzag fashion too
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
if attention_mask.bool().all():
|
||||
labels = split_batch_zigzag(labels, sp_group, seq_dim=1)
|
||||
labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
|
||||
else:
|
||||
# [B, max_seqlen // sp_size]
|
||||
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
|
||||
|
@ -345,16 +338,17 @@ class LlamaPipelineForwards:
|
|||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
force_sp_output_gather=False,
|
||||
force_sp_gather=False,
|
||||
)
|
||||
past_key_values = None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
disable_pp = stage_manager is None
|
||||
if disable_pp or stage_manager.is_last_stage():
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
|
||||
)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -629,263 +623,3 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
|||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
# Split output only when computing cross entropy using llama_for_causal_lm_forward
|
||||
# or get_lm_forward_with_dist_cross_entropy
|
||||
# Default to True to avoid bug when calling classification forward from huggingface
|
||||
force_sp_output_gather: bool = True,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
past_seen_tokens = 0
|
||||
seq_len = inputs_embeds.shape[1]
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
if use_cache: # kept for BC (cache positions)
|
||||
if not isinstance(past_key_values, StaticCache):
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_seen_tokens = past_key_values.get_seq_length()
|
||||
|
||||
if cache_position is None:
|
||||
if isinstance(past_key_values, StaticCache):
|
||||
raise ValueError("cache_position is a required argument when using StaticCache.")
|
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
if shard_config.enable_flash_attention:
|
||||
mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len)
|
||||
attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
inputs_embeds.dtype,
|
||||
inputs_embeds.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
invert=(sp_mode != "ring_attn"),
|
||||
)
|
||||
|
||||
else:
|
||||
attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
||||
|
||||
# Ring Attention zigzag batch processing
|
||||
if sp_mode == "ring_attn":
|
||||
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
|
||||
if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
|
||||
inputs_embeds, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
|
||||
attention_mask, sp_group, inputs_embeds, position_ids
|
||||
)
|
||||
else:
|
||||
inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group)
|
||||
attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]} # drop redundant tensors
|
||||
|
||||
elif is_share_sp_tp(sp_mode):
|
||||
inputs_embeds = split_forward_gather_backward(
|
||||
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
inputs_embeds = split_forward_gather_backward(
|
||||
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attn_kwargs,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attn_kwargs,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
# Cases that don't support parallelizing cross entropy computation along sequence
|
||||
if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
|
||||
hidden_states = gather_sp_output(
|
||||
hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = (
|
||||
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
|
||||
)
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
def forward(
|
||||
self: LlamaForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output:
|
||||
# Special processing: Split labels in a zigzag fashion too
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
if attention_mask.bool().all():
|
||||
labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
|
||||
else:
|
||||
# [B, max_seq_len // sp_size]
|
||||
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
force_sp_output_gather=False,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if self.config.pretraining_tp > 1:
|
||||
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
||||
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
logits = torch.cat(logits, dim=-1)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
|
||||
)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
|
|
@ -274,10 +274,9 @@ class MistralForwards:
|
|||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
|
||||
)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -687,10 +686,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
|
||||
)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
|
|
@ -330,14 +330,15 @@ class OPTPipelineForwards:
|
|||
)
|
||||
if stage_manager.is_last_stage():
|
||||
logits = self.lm_head(outputs[0]).contiguous()
|
||||
loss = dist_cross_entropy(
|
||||
labels,
|
||||
logits,
|
||||
shard_config,
|
||||
self.lm_head.out_features,
|
||||
self.config.vocab_size,
|
||||
self.model.decoder.dtype,
|
||||
)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(
|
||||
labels,
|
||||
logits,
|
||||
shard_config,
|
||||
self.lm_head.out_features,
|
||||
self.model.decoder.dtype,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -955,9 +956,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
)
|
||||
|
||||
logits = self.lm_head(outputs[0]).contiguous()
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.decoder.dtype
|
||||
)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.decoder.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
|
|
@ -32,14 +32,12 @@ except ImportError:
|
|||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer._operation import (
|
||||
all_to_all_comm,
|
||||
gather_forward_split_backward,
|
||||
split_forward_gather_backward,
|
||||
)
|
||||
from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import ColoAttention, dist_cross_entropy
|
||||
from ..layer._operation import gather_sp_output
|
||||
from ..layer.utils import is_share_sp_tp
|
||||
|
||||
|
||||
class Qwen2PipelineForwards:
|
||||
|
@ -64,6 +62,7 @@ class Qwen2PipelineForwards:
|
|||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
force_sp_output_gather: bool = True,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
@ -115,6 +114,14 @@ class Qwen2PipelineForwards:
|
|||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
# Support SP + PP
|
||||
sp_size = shard_config.sequence_parallel_size
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
# For generating full positions ids (the states will be gathered along the seq dim before attention fwd).
|
||||
if sp_mode != "ring_attn" and not stage_manager.is_first_stage():
|
||||
seq_length *= sp_size
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
|
@ -151,7 +158,6 @@ class Qwen2PipelineForwards:
|
|||
elif self._attn_implementation == "sdpa" and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
|
@ -160,7 +166,6 @@ class Qwen2PipelineForwards:
|
|||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
|
@ -169,22 +174,21 @@ class Qwen2PipelineForwards:
|
|||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
|
||||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=1 / shard_config.sequence_parallel_size,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
if stage_manager.is_first_stage():
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
if is_share_sp_tp(sp_mode):
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=sp_group,
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=sp_group,
|
||||
grad_scale=1 / sp_size,
|
||||
)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
@ -241,23 +245,10 @@ class Qwen2PipelineForwards:
|
|||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
|
||||
hidden_states = gather_sp_output(hidden_states, shard_config)
|
||||
|
||||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=shard_config.sequence_parallel_size,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
@ -351,15 +342,18 @@ class Qwen2PipelineForwards:
|
|||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
force_sp_output_gather=False,
|
||||
)
|
||||
past_key_values = None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = outputs[0]
|
||||
if hidden_states.shape[1] == 2:
|
||||
pass
|
||||
logits = self.lm_head(hidden_states)
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
|
||||
)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -541,7 +535,6 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
|||
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
|
@ -635,6 +628,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
force_sp_output_gather: bool = True,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
|
@ -750,14 +744,9 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
|
|||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if sp_mode == "ring" or sp_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
|
||||
hidden_states = gather_sp_output(hidden_states, shard_config)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
|
@ -834,14 +823,15 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
force_sp_output_gather=False,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
|
||||
)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
|
|
@ -64,7 +64,7 @@ class ChatGLMPolicy(Policy):
|
|||
|
||||
if sp_mode == "ring":
|
||||
warnings.warn(
|
||||
f"For ChatGLM2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
|
||||
f"For ChatGLM2, sequence parallelism doesn't support mode {sp_mode} yet, will set to be split_gather"
|
||||
)
|
||||
sp_mode = "split_gather"
|
||||
overlap = self.shard_config.enable_sequence_overlap
|
||||
|
|
|
@ -6,14 +6,7 @@ from torch import Tensor, nn
|
|||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
|
||||
from ..modeling.gpt2 import (
|
||||
GPT2PipelineForwards,
|
||||
get_gpt2_flash_attention_forward,
|
||||
get_gpt_model_forward_for_flash_attn,
|
||||
get_jit_fused_gpt2_mlp_forward,
|
||||
get_lm_forward_with_dist_cross_entropy,
|
||||
gpt2_sequence_parallel_forward_fn,
|
||||
)
|
||||
from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, get_jit_fused_gpt2_mlp_forward
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = [
|
||||
|
@ -71,18 +64,10 @@ class GPT2Policy(Policy):
|
|||
warnings.warn(
|
||||
f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
|
||||
)
|
||||
sp_mode = "split_gather"
|
||||
self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather"
|
||||
overlap = self.shard_config.enable_sequence_overlap
|
||||
sp_partial_derived = sp_mode in ["split_gather", "ring"]
|
||||
use_flash_attention = self.shard_config.enable_flash_attention
|
||||
# todo: currently sp cannot be used with flashattention
|
||||
if sp_mode in ["split_gather", "ring", "all_to_all"]:
|
||||
if use_flash_attention:
|
||||
warnings.warn(
|
||||
f"Sequence parallelism mode {sp_mode} cannot be used with FlashAttention, will disable FlashAttention automatically."
|
||||
)
|
||||
self.shard_config.enable_flash_attention = False
|
||||
use_flash_attention = False
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
||||
|
@ -211,18 +196,16 @@ class GPT2Policy(Policy):
|
|||
if use_flash_attention:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_gpt2_flash_attention_forward(),
|
||||
"forward": get_gpt2_flash_attention_forward(shard_config=self.shard_config),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=attn_cls,
|
||||
)
|
||||
if not self.shard_config.pipeline_stage_manager:
|
||||
policy[GPT2Model].method_replacement = {
|
||||
"forward": get_gpt_model_forward_for_flash_attn(self.shard_config)
|
||||
}
|
||||
|
||||
if sp_mode is not None:
|
||||
policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
|
||||
if not self.shard_config.pipeline_stage_manager and self.shard_config.enable_sequence_parallelism:
|
||||
policy[GPT2Model].method_replacement = {
|
||||
"forward": partial(GPT2PipelineForwards.gpt2_model_forward, shard_config=self.shard_config)
|
||||
}
|
||||
|
||||
return policy
|
||||
|
||||
|
@ -328,40 +311,39 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
|
|||
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
||||
|
||||
module_policy = super().module_policy()
|
||||
|
||||
module_policy[GPT2LMHeadModel] = ModulePolicyDescription()
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
addon_module = {
|
||||
GPT2LMHeadModel: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=col_nn.VocabParallelLMHead1D,
|
||||
kwargs={
|
||||
"gather_output": False,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
}
|
||||
if self.shard_config.parallel_output:
|
||||
addon_module[GPT2LMHeadModel].method_replacement = {
|
||||
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
|
||||
}
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=col_nn.VocabParallelLMHead1D,
|
||||
kwargs={
|
||||
"gather_output": False,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
},
|
||||
),
|
||||
policy=module_policy,
|
||||
target_key=GPT2LMHeadModel,
|
||||
)
|
||||
else:
|
||||
addon_module = {
|
||||
GPT2LMHeadModel: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=col_nn.PaddingLMHead,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=col_nn.PaddingLMHead,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
policy=module_policy,
|
||||
target_key=GPT2LMHeadModel,
|
||||
)
|
||||
|
||||
if self.shard_config.parallel_output:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config)
|
||||
},
|
||||
policy=module_policy,
|
||||
target_key=GPT2LMHeadModel,
|
||||
)
|
||||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
self.set_pipeline_forward(
|
||||
|
|
|
@ -16,12 +16,7 @@ from colossalai.shardformer.layer import (
|
|||
VocabParallelLMHead1D,
|
||||
)
|
||||
|
||||
from ..modeling.llama import (
|
||||
LlamaPipelineForwards,
|
||||
get_llama_flash_attention_forward,
|
||||
get_llama_flash_attention_model_forward,
|
||||
get_lm_forward_with_dist_cross_entropy,
|
||||
)
|
||||
from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"]
|
||||
|
@ -99,11 +94,9 @@ class LlamaPolicy(Policy):
|
|||
if self.pipeline_stage_manager is None:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_llama_flash_attention_model_forward(
|
||||
self.shard_config,
|
||||
sp_mode=sp_mode,
|
||||
sp_size=sp_size,
|
||||
sp_group=sp_group,
|
||||
"forward": partial(
|
||||
LlamaPipelineForwards.llama_model_forward,
|
||||
shard_config=self.shard_config,
|
||||
),
|
||||
},
|
||||
policy=policy,
|
||||
|
@ -351,7 +344,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
|||
elif self.shard_config.enable_tensor_parallelism or self.shard_config.enable_sequence_parallelism:
|
||||
# Compute loss distributedly along the sequence dimension
|
||||
new_item[LlamaForCausalLM].method_replacement = {
|
||||
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
|
||||
# "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
|
||||
"forward": partial(LlamaPipelineForwards.llama_for_causal_lm_forward, shard_config=self.shard_config)
|
||||
}
|
||||
return policy
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ MODEL_CONFIGS = {
|
|||
"118M": GPT2Config(activation_function="gelu"),
|
||||
"338M": GPT2Config(n_embd=1024, n_head=16, n_layer=24, activation_function="gelu"),
|
||||
"738M": GPT2Config(n_embd=1280, n_head=20, n_layer=36, activation_function="gelu"),
|
||||
"6.21B": GPT2Config(n_embd=4096, n_head=32, n_layer=32, n_positions=4096, activation_function="gelu"),
|
||||
"6.21B": GPT2Config(n_embd=4096, n_head=32, n_layer=32, n_positions=32768, activation_function="gelu"),
|
||||
}
|
||||
|
||||
|
||||
|
@ -60,6 +60,8 @@ def main():
|
|||
parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
|
||||
parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
|
||||
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
|
||||
parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size")
|
||||
parser.add_argument("--sp_mode", type=str, default="ring_attn", help="Sequence parallel mode")
|
||||
parser.add_argument("--mbs", type=int, default=1)
|
||||
parser.add_argument("--zero", type=int, default=0)
|
||||
parser.add_argument("--pp_style", type=str, default="1f1b")
|
||||
|
@ -129,6 +131,9 @@ def main():
|
|||
tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
pp_style=args.pp_style,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
enable_sequence_parallelism=True,
|
||||
zero_stage=args.zero,
|
||||
num_model_chunks=args.num_model_chunks,
|
||||
enable_all_optimization=True,
|
||||
|
@ -214,6 +219,8 @@ def main():
|
|||
performance_evaluator.on_step_start(step)
|
||||
outputs = model(**batch)
|
||||
loss = outputs[0]
|
||||
del outputs
|
||||
|
||||
booster.backward(loss, optimizer)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
|
|
@ -6,7 +6,6 @@ import torch.distributed as dist
|
|||
from torch import Tensor
|
||||
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
||||
|
||||
|
@ -22,8 +21,11 @@ def divide(x: float, y: float) -> float:
|
|||
def all_reduce_mean(x: float, world_size: int) -> float:
|
||||
if world_size == 1:
|
||||
return x
|
||||
tensor = torch.tensor([x], device=get_accelerator().get_current_device())
|
||||
dist.all_reduce(tensor)
|
||||
|
||||
# Use CPU tensor to avoid OOM/weird NCCl error
|
||||
gloo_group = dist.new_group(backend="gloo")
|
||||
tensor = torch.tensor([x], device="cpu")
|
||||
dist.all_reduce(tensor, group=gloo_group)
|
||||
tensor = tensor / world_size
|
||||
return tensor.item()
|
||||
|
||||
|
|
|
@ -27,7 +27,16 @@ def data_gen_for_lm():
|
|||
# LM data gen
|
||||
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
|
||||
data = data_gen()
|
||||
data["labels"] = data["input_ids"].clone()
|
||||
|
||||
# Test padded sequence for Ring Attention
|
||||
padding = torch.zeros(1, data["input_ids"].shape[1] // 2, dtype=torch.long)
|
||||
data["input_ids"] = torch.cat([data["input_ids"], padding], dim=1)
|
||||
data["attention_mask"] = torch.cat([data["attention_mask"], padding], dim=1)
|
||||
|
||||
ignore_idx = -100
|
||||
labels = data["input_ids"].clone()
|
||||
labels[~data["attention_mask"].bool()] = ignore_idx
|
||||
data["labels"] = labels
|
||||
return data
|
||||
|
||||
|
||||
|
|
|
@ -157,7 +157,6 @@ def build_model_from_hybrid_plugin(
|
|||
sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
|
||||
|
||||
criterion = loss_fn
|
||||
|
||||
plugin = pluggin_cls(**test_config)
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
|
@ -323,7 +322,6 @@ def check_output_hidden_state(
|
|||
sp_size = shard_config.sequence_parallel_size
|
||||
if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size:
|
||||
org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)]
|
||||
|
||||
assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
|
|
|
@ -136,6 +136,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
@parameterize(
|
||||
"test_config",
|
||||
[
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 2,
|
||||
"enable_sequence_parallelism": True,
|
||||
"sequence_parallelism_mode": "split_gather",
|
||||
"enable_flash_attention": True,
|
||||
"use_lazy_init": True,
|
||||
"zero_stage": 1,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{ # Ulysess + Flash attention
|
||||
"tp_size": 1,
|
||||
"pp_size": 2,
|
||||
|
@ -149,19 +161,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"sp_size": 2,
|
||||
"num_microbatches": 2,
|
||||
"enable_sequence_parallelism": True,
|
||||
"sequence_parallelism_mode": "split_gather",
|
||||
"enable_flash_attention": True,
|
||||
"use_lazy_init": True,
|
||||
"zero_stage": 1,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 1,
|
||||
"pp_size": 1,
|
||||
|
@ -174,17 +173,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 4,
|
||||
"pp_size": 1,
|
||||
"num_microbatches": 1,
|
||||
"enable_sequence_parallelism": True,
|
||||
"sequence_parallelism_mode": "ring",
|
||||
"enable_flash_attention": False,
|
||||
"use_lazy_init": True,
|
||||
"precision": "fp32",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 4,
|
||||
"pp_size": 1,
|
||||
|
@ -248,7 +236,11 @@ def run_chatglm_test(test_config):
|
|||
loss_fn,
|
||||
_,
|
||||
) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
try:
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
except Exception as e:
|
||||
print(f"Test config failed for model {name}: {test_config}")
|
||||
raise e
|
||||
|
||||
clear_layout_converter()
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
@ -125,7 +125,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
if org_model.__class__.__name__ == "CohereModel":
|
||||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
|
||||
check_output_hidden_state(
|
||||
org_output,
|
||||
sharded_output,
|
||||
stage_manager,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
shard_config=booster.plugin.shard_config,
|
||||
)
|
||||
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
|
@ -274,7 +281,11 @@ def run_command_test(test_config):
|
|||
sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm")
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
try:
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
except Exception as e:
|
||||
print(f"Failed test config: {test_config}")
|
||||
raise e
|
||||
|
||||
clear_layout_converter()
|
||||
Randomizer.reset_index()
|
||||
|
|
|
@ -100,7 +100,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
if org_model.__class__.__name__ == "GPT2Model":
|
||||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
|
||||
check_output_hidden_state(
|
||||
org_output,
|
||||
sharded_output,
|
||||
stage_manager,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
shard_config=booster.plugin.shard_config,
|
||||
)
|
||||
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
|
@ -132,14 +139,27 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
"test_config",
|
||||
[
|
||||
{
|
||||
"tp_size": 4,
|
||||
"pp_size": 1,
|
||||
"num_microbatches": 1,
|
||||
"sp_size": 2,
|
||||
"tp_size": 1,
|
||||
"pp_size": 2,
|
||||
"enable_sequence_parallelism": True,
|
||||
"sequence_parallelism_mode": "ring",
|
||||
"enable_flash_attention": False,
|
||||
"sequence_parallelism_mode": "ring_attn",
|
||||
"num_microbatches": 2,
|
||||
"enable_all_optimization": True,
|
||||
"use_lazy_init": True,
|
||||
"precision": "fp32",
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"sp_size": 2,
|
||||
"tp_size": 2,
|
||||
"pp_size": 1,
|
||||
"enable_sequence_parallelism": True,
|
||||
"sequence_parallelism_mode": "ring_attn",
|
||||
"num_microbatches": 1,
|
||||
"enable_all_optimization": True,
|
||||
"use_lazy_init": True,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
|
@ -148,7 +168,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
"num_microbatches": 1,
|
||||
"enable_sequence_parallelism": True,
|
||||
"sequence_parallelism_mode": "split_gather",
|
||||
"enable_flash_attention": False,
|
||||
"enable_flash_attention": True,
|
||||
"use_lazy_init": True,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
|
@ -156,7 +176,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 4,
|
||||
"num_microbatches": 2,
|
||||
"enable_sequence_parallelism": True,
|
||||
"sequence_parallelism_mode": "split_gather",
|
||||
"enable_flash_attention": True,
|
||||
"use_lazy_init": True,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 2,
|
||||
"enable_all_optimization": True,
|
||||
"use_lazy_init": True,
|
||||
"precision": "fp16",
|
||||
|
@ -185,7 +216,16 @@ def run_gpt2_test(test_config):
|
|||
loss_fn,
|
||||
_,
|
||||
) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and name != "transformers_gpt_lm":
|
||||
# Only wrote zigzag splitting for cross entropy loss
|
||||
continue
|
||||
|
||||
try:
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
except Exception as e:
|
||||
print(f"Failed config: {test_config} for model {name}")
|
||||
raise (e)
|
||||
|
||||
clear_layout_converter()
|
||||
torch.cuda.empty_cache()
|
||||
|
@ -226,7 +266,11 @@ def run_gpt2_3d_test(test_config):
|
|||
loss_fn,
|
||||
_,
|
||||
) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
try:
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
except Exception as e:
|
||||
print(f"Failed config: {test_config} for model {name}")
|
||||
raise (e)
|
||||
|
||||
clear_layout_converter()
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
@ -165,7 +165,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
"zero_stage": 0,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
"inner_ring_size": 2,
|
||||
},
|
||||
# Ring Attention + PP
|
||||
{
|
||||
|
@ -215,18 +214,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
"sequence_parallelism_mode": "all_to_all",
|
||||
"enable_all_optimization": True,
|
||||
"use_lazy_init": True,
|
||||
"zero_stage": 0,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 4,
|
||||
"pp_size": 1,
|
||||
"num_microbatches": 1,
|
||||
"enable_sequence_parallelism": True,
|
||||
"sequence_parallelism_mode": "split_gather",
|
||||
"enable_flash_attention": True,
|
||||
"use_lazy_init": True,
|
||||
"zero_stage": 1,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
|
@ -294,6 +282,7 @@ def run_llama_test(test_config):
|
|||
except Exception as e:
|
||||
print(f"Failed config: {test_config}, model name: {name}")
|
||||
raise e
|
||||
|
||||
clear_layout_converter()
|
||||
Randomizer.reset_index()
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
@ -94,6 +94,32 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
@parameterize(
|
||||
"test_config",
|
||||
[
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"sp_size": 2,
|
||||
"num_microbatches": 2,
|
||||
"enable_sequence_parallelism": True,
|
||||
"sequence_parallelism_mode": "split_gather",
|
||||
"enable_flash_attention": True,
|
||||
"use_lazy_init": True,
|
||||
"zero_stage": 1,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{ # Ulysess + Flash attention
|
||||
"tp_size": 1,
|
||||
"pp_size": 2,
|
||||
"sp_size": 2,
|
||||
"num_microbatches": 2,
|
||||
"enable_sequence_parallelism": True,
|
||||
"sequence_parallelism_mode": "all_to_all",
|
||||
"enable_flash_attention": True,
|
||||
"use_lazy_init": True,
|
||||
"zero_stage": 1,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
|
@ -135,32 +161,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{ # Ulysess + Flash attention
|
||||
"tp_size": 1,
|
||||
"pp_size": 2,
|
||||
"sp_size": 2,
|
||||
"num_microbatches": 2,
|
||||
"enable_sequence_parallelism": True,
|
||||
"sequence_parallelism_mode": "all_to_all",
|
||||
"enable_flash_attention": True,
|
||||
"use_lazy_init": True,
|
||||
"zero_stage": 1,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"sp_size": 2,
|
||||
"num_microbatches": 2,
|
||||
"enable_sequence_parallelism": True,
|
||||
"sequence_parallelism_mode": "split_gather",
|
||||
"enable_flash_attention": True,
|
||||
"use_lazy_init": True,
|
||||
"zero_stage": 1,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
|
|
Loading…
Reference in New Issue