mirror of https://github.com/hpcaitech/ColossalAI
[Feature] Split cross-entropy computation in SP (#5959)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements
* adapt chatglm, command-R, qwen
* debug
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* add sp_mode to benchmark; fix varlen interface
* update softmax_lse shape by new interface
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements
* add comments
* q1 index only once
* remove events to simplify stream sync
* simplify forward/backward logic
* 2d ring forward passed
* 2d ring backward passed
* fixes
* fix ring attn loss
* 2D ring backward + llama passed
* merge
* update logger
* fix typo
* rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix typo
* remove typos
* fixes
* support GPT

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b3db1058ec
commit 8fd25d6e09
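The commit message above describes splitting the cross-entropy computation across sequence-parallel (SP) ranks instead of gathering the full logits. A minimal, hypothetical sketch of that idea (not the ColossalAI implementation; all names here are illustrative) is to let each rank compute a local loss sum over its sequence shard and then all-reduce the sums and token counts:

    # Hypothetical sketch of sequence-parallel cross entropy: each rank holds only
    # its shard of the sequence, so it computes a local loss sum and token count,
    # then all-reduces both before normalizing.
    import torch
    import torch.distributed as dist
    import torch.nn.functional as F

    def sp_cross_entropy(local_logits: torch.Tensor, local_labels: torch.Tensor, sp_group) -> torch.Tensor:
        # local_logits: [B, S_local, V]; local_labels: [B, S_local] with -100 marking ignored tokens
        loss_sum = F.cross_entropy(
            local_logits.view(-1, local_logits.size(-1)),
            local_labels.view(-1),
            ignore_index=-100,
            reduction="sum",
        )
        num_tokens = (local_labels != -100).sum().to(loss_sum.dtype)
        # Combine the shard-local sums across the sequence-parallel group.
        dist.all_reduce(loss_sum, group=sp_group)
        dist.all_reduce(num_tokens, group=sp_group)
        return loss_sum / num_tokens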
@@ -1097,13 +1097,19 @@ def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8
     return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
 
 
-def gather_sp_output(hidden_states, sp_group, sp_mode, fp8_communication=False):
+def gather_sp_output(hidden_states, shard_config, sp_dim=1):
     """
     Gather the output of the last layer for cross entropy computation
     """
+    sp_group = shard_config.sequence_parallel_process_group
+    sp_mode = shard_config.sequence_parallelism_mode
+    fp8_comm = shard_config.fp8_communication
+    if dist.get_world_size(sp_group) == 1:
+        return hidden_states
+
     # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
     scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
     hidden_states = gather_forward_split_backward(
-        hidden_states, 1, sp_group, grad_scale=scale, fp8_communication=fp8_communication
+        hidden_states, sp_dim, sp_group, grad_scale=scale, fp8_communication=fp8_comm
     )
     return hidden_states
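The reworked helper above takes the whole `shard_config` instead of separate group/mode/fp8 arguments. A hedged usage sketch, assuming an initialized ShardFormer setup (names follow the hunk above):

    # `hidden_states` is the last layer's output, still split along the sequence
    # dim; `shard_config` is the plugin's ShardConfig. sp_dim picks the sequence
    # dimension (1 for [B, S, H] layouts, 0 for seq-first models such as ChatGLM).
    from colossalai.shardformer.layer._operation import gather_sp_output

    if shard_config.enable_sequence_parallelism:
        hidden_states = gather_sp_output(hidden_states, shard_config, sp_dim=1)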
@@ -433,7 +433,6 @@ class RingAttention(torch.autograd.Function):
        assert (
            sp_size % inner_ring_size == 0
        ), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
-
        logger = get_dist_logger()
        logger.info(
            f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!",
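The log line above hints at the intent of the 2D layout: keep most KV passing inside an inner ring (typically one node) and reserve the outer ring for inter-node traffic. A small worked example of the divisibility check, with illustrative numbers only:

    # Illustrative numbers: 2 nodes x 4 GPUs gives sp_size = 8.
    sp_size = 8
    inner_ring_size = 4                      # ranks per intra-node ring
    assert sp_size % inner_ring_size == 0, "sp_size must be divisible by inner_ring_size"
    num_rings = sp_size // inner_ring_size   # 2 inner rings, linked by an inter-node ring
    print(f"{num_rings} inner rings of {inner_ring_size} ranks each")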
@@ -898,6 +897,7 @@ class RingAttention(torch.autograd.Function):
 
        local_sp_rank = dist.get_rank(sp_group)
        sp_size = dist.get_world_size(sp_group)
+
        # Using separate streams (pg) for concurrent kv and dkv comm may
        # cause NCCL "software caused connection abort" here...
        local_kv_comm = RingComm(local_kv_group)
@@ -1119,9 +1119,14 @@ class RingAttention(torch.autograd.Function):
            the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.
 
        Returns:
-            inputs_embeds: Packed input embeddings of shape [B, Sq // sp_size, ...].
-            mask_info: A dictionary of mask info.
-            position_ids: Packed position ids of shape [..., Sq // sp_size].
+            torch.Tensor:
+                Packed input embeddings of shape [B, Sq // sp_size, ...].
+
+            Dict[str, Any]:
+                A dictionary containing mask info.
+
+            torch.Tensor:
+                Packed position ids of shape [..., Sq // sp_size].
 
        """
        _load_varlen_helpers()
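A hedged sketch of calling the method with the documented return order, assuming `attention_mask`, `sp_group`, `inputs_embeds`, and `position_ids` come from an already initialized sequence-parallel forward (the argument order mirrors the GPT-2 changes later in this diff):

    # Returns follow the docstring above: packed embeddings, a dict of mask info
    # (e.g. varlen metadata), and packed position ids, each already sharded over sp_group.
    inputs_embeds, mask_info, position_ids = RingAttention.prepare_varlen_batch(
        attention_mask, sp_group, inputs_embeds, position_ids
    )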
@@ -153,7 +153,6 @@ def dist_cross_entropy(
    labels: torch.Tensor,  # [B, S] or [B, S, Vocab_size]
    logits: torch.Tensor,  # [B, S, Vocab_size]
    shard_config: ShardConfig,
-    out_features: int,
    vocab_size: int,
    dtype: torch.dtype,
    seq_dim: int = 1,
@@ -226,13 +225,13 @@ def dist_cross_entropy(
            logits,
            labels,
            process_group=shard_config.tensor_parallel_process_group,
-            vocab_size=out_features,
+            vocab_size=vocab_size,
            dtype=dtype,
            mode="sum",
        )
    else:
        # NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D
-        logits = logits.view(-1, vocab_size)
+        logits = logits.view(-1, logits.size(-1))
        loss = loss_fct(logits, labels)

    # Reduce loss instead of gathering logits over seq dim for savings
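With `out_features` dropped from the signature, callers now pass the LM head's output feature count in the `vocab_size` slot. A hedged sketch of the updated call, mirroring the model changes later in this diff:

    # `self.lm_head.out_features` is the (possibly sharded) head's output width;
    # dist_cross_entropy then picks between the vocab-parallel TP path and the
    # plain view(-1, V) path shown in the hunk above.
    loss = None
    if labels is not None:
        loss = dist_cross_entropy(
            labels, logits, shard_config, self.lm_head.out_features, self.model.dtype
        )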
@@ -313,19 +313,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
 
        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
-        if self.seq_parallel_mode is None:
-            # Set up backprop all-reduce.
-            input_parallel = reduce_backward(input_, self.process_group, fp8_communication=self.fp8_communication)
-            output_parallel = matmul_with_async_comm(
-                input_parallel,
-                self.weight,
-                bias,
-                self.process_group,
-                self.async_communication,
-                fp8_communication=self.fp8_communication,
-            )
-        elif self.seq_parallel_mode == "split_gather":
+        if self.seq_parallel_mode == "split_gather":
            input_parallel = input_
            output_parallel = matmul_gather_forward_reducescatter_backward(
                input_parallel,
@@ -340,8 +328,29 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
        elif self.seq_parallel_mode == "ring":
            input_parallel = input_
            output_parallel = matmul_gather_forward_reducescatter_backward(
-                input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap, True
+                input_parallel,
+                self.weight,
+                bias,
+                self.process_group,
+                True,
+                1,
+                self.overlap,
+                True,
+                fp8_communication=self.fp8_communication,
            )
+        elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
+            # Set up backprop all-reduce.
+            input_parallel = reduce_backward(input_, self.process_group)
+            output_parallel = matmul_with_async_comm(
+                input_parallel,
+                self.weight,
+                bias,
+                self.process_group,
+                self.async_communication,
+                fp8_communication=self.fp8_communication,
+            )
+        else:
+            raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
 
        if self.gather_output:
            # All-gather across the partitions.
@@ -553,7 +562,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
                handle.wait()
            output = torch.cat(output_parallel_list, dim=-1)
        else:
-            if self.seq_parallel_mode is None:
+            if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
                output_parallel = torch.matmul(input_, self.weight)
                output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
            elif self.seq_parallel_mode == "split_gather":
@@ -567,8 +576,12 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
            elif self.seq_parallel_mode == "ring":
                output_parallel = torch.matmul(input_, self.weight)
                output = reducescatter_forward_gather_backward(
-                    output_parallel, self.process_group, 1, self.fp8_communication
+                    output_parallel,
+                    self.process_group,
+                    1,
                )
+            else:
+                raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
 
        if not self.skip_bias_add:
            if self.bias is not None:
@@ -309,6 +309,9 @@ def split_batch_zigzag(
    """
    sp_size = dist.get_world_size(sp_group)
    sp_rank = dist.get_rank(sp_group)
+    if sp_size == 1:
+        return batch
+
    if isinstance(batch, torch.Tensor):
        batch = [batch]
    seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1
@@ -364,6 +367,9 @@ def split_varlen_zigzag(
    """
    sp_size = dist.get_world_size(sp_group)
    sp_rank = dist.get_rank(sp_group)
+    if sp_size == 1:
+        return batch
+
    if is_2d:
        assert max_seqlen > 0, "max_seqlen must be provided for 2D input"
 
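Both splitters above now short-circuit when the SP group has a single rank. The zigzag layout itself pairs an early chunk with a late chunk on each rank so that causal-attention work is balanced across the ring. A minimal standalone illustration of that chunk assignment (this mirrors the idea behind `split_batch_zigzag` but is not the library implementation):

    # Rank r gets chunk r and chunk (2 * sp_size - 1 - r), pairing cheap early
    # positions with expensive late ones under a causal mask.
    import torch

    def zigzag_shard(seq: torch.Tensor, sp_size: int, rank: int, seq_dim: int = 1) -> torch.Tensor:
        chunks = seq.chunk(2 * sp_size, dim=seq_dim)
        return torch.cat([chunks[rank], chunks[2 * sp_size - 1 - rank]], dim=seq_dim)

    x = torch.arange(16).view(1, 16)           # a toy sequence of 16 positions
    print(zigzag_shard(x, sp_size=4, rank=0))  # positions 0-1 and 14-15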
@@ -365,12 +365,13 @@ class BloomPipelineForwards:
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states).contiguous()
 
+        loss = None
+        if labels is not None:
            loss = dist_cross_entropy(
                labels,
                lm_logits,
                shard_config,
                self.lm_head.out_features,
-                self.config.vocab_size,
                self.transformer.dtype,
            )
 
@@ -1036,8 +1037,10 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)
 
+        loss = None
+        if labels is not None:
            loss = dist_cross_entropy(
-                labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
+                labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype
            )
 
        if not return_dict:
@@ -4,7 +4,6 @@ from typing import List, Optional, Tuple
 
 import torch
 import torch.utils.checkpoint
-from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.utils import logging
 
@@ -13,10 +12,13 @@ from colossalai.shardformer import ShardConfig
 from colossalai.shardformer.layer import AttnMaskType, ColoAttention
 from colossalai.shardformer.layer._operation import (
     all_to_all_comm,
-    gather_forward_split_backward,
+    gather_sp_output,
+    is_share_sp_tp,
     split_forward_gather_backward,
 )
 
+from ..layer import dist_cross_entropy
+
 
 def get_flash_core_attention_forward():
     from .chatglm2_6b.modeling_chatglm import CoreAttention
@@ -138,6 +140,7 @@ class ChatGLMPipelineForwards:
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
+        force_sp_output_gather: Optional[bool] = True,
    ):
        logger = logging.get_logger(__name__)
        output_hidden_states = (
@@ -180,6 +183,15 @@ class ChatGLMPipelineForwards:
        if full_attention_mask is None:
            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
+
+        # Support SP + PP
+        sp_size = shard_config.sequence_parallel_size
+        sp_mode = shard_config.sequence_parallelism_mode
+        sp_group = shard_config.sequence_parallel_process_group
+        # For generating full positions ids (the states will be gathered along the seq dim before attention fwd).
+        if sp_mode != "ring_attn" and not stage_manager.is_first_stage():
+            seq_length *= sp_size
+
        # Rotary positional embeddings
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        if position_ids is not None:
@@ -200,21 +212,14 @@ class ChatGLMPipelineForwards:
        all_hidden_states = () if output_hidden_states else None
        start_idx, end_idx = stage_index[0], stage_index[1]
 
-        if shard_config and shard_config.enable_sequence_parallelism:
-            if shard_config.sequence_parallelism_mode == "split_gather":
+        # Keep the input split across all PP stages
+        if stage_manager.is_first_stage():
+            if shard_config.enable_sequence_parallelism:
+                if sp_mode == "split_gather":
                    hidden_states = split_forward_gather_backward(
                        hidden_states,
                        dim=0,
-                        process_group=shard_config.tensor_parallel_process_group,
-                        fp8_communication=shard_config.fp8_communication,
-                    )
-            elif shard_config.sequence_parallelism_mode == "all_to_all":
-                hidden_states = split_forward_gather_backward(
-                    hidden_states,
-                    dim=0,
-                    process_group=shard_config.sequence_parallel_process_group,
-                    grad_scale=1 / shard_config.sequence_parallel_size,
-                    fp8_communication=shard_config.fp8_communication,
+                        process_group=sp_group,
                    )
                elif shard_config.sequence_parallelism_mode == "all_to_all":
                    hidden_states = split_forward_gather_backward(
@@ -223,6 +228,7 @@ class ChatGLMPipelineForwards:
                        process_group=shard_config.sequence_parallel_process_group,
                        grad_scale=1 / shard_config.sequence_parallel_size,
                    )
+
        for idx in range(start_idx, end_idx):
            layer = self.encoder._get_layer(idx)
            if output_hidden_states:
@@ -248,35 +254,19 @@ class ChatGLMPipelineForwards:
            if use_cache:
                presents = presents + (kv_cache,)
 
-        if shard_config and shard_config.enable_sequence_parallelism:
-            if shard_config.sequence_parallelism_mode == "split_gather":
-                hidden_states = gather_forward_split_backward(
-                    hidden_states,
-                    dim=0,
-                    process_group=shard_config.tensor_parallel_process_group,
-                    fp8_communication=shard_config.fp8_communication,
-                )
-            elif shard_config.sequence_parallelism_mode == "all_to_all":
-                hidden_states = gather_forward_split_backward(
-                    hidden_states,
-                    dim=0,
-                    process_group=shard_config.sequence_parallel_process_group,
-                    grad_scale=shard_config.sequence_parallel_size,
-                    fp8_communication=shard_config.fp8_communication,
-                )
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        if stage_manager.is_last_stage():
            # final layer_norm
            if self.encoder.post_layer_norm:
                hidden_states = self.encoder.final_layernorm(hidden_states)
 
+            # Gather seq-wise in the final output stage
+            if shard_config.enable_sequence_parallelism:
+                sp_mode = shard_config.sequence_parallelism_mode
+                if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
+                    hidden_states = gather_sp_output(hidden_states, shard_config, sp_dim=0)
+
        if not return_dict:
            return tuple(
                v
@@ -333,6 +323,7 @@ class ChatGLMPipelineForwards:
            hidden_states=hidden_states,
            stage_index=stage_index,
            shard_config=shard_config,
+            force_sp_output_gather=False,
        )
        if stage_manager.is_last_stage():
            hidden_states = transformer_outputs[0]
@@ -340,17 +331,21 @@ class ChatGLMPipelineForwards:
                hidden_states = hidden_states[-1:]
            lm_logits = self.transformer.output_layer(hidden_states)
            lm_logits = lm_logits.transpose(0, 1).contiguous()
 
            loss = None
            if labels is not None:
-                lm_logits = lm_logits.to(torch.float32)
-                # Shift so that tokens < n predict n
-                shift_logits = lm_logits[..., :-1, :].contiguous()
-                shift_labels = labels[..., 1:].contiguous()
-                # Flatten the tokens
-                loss_fct = CrossEntropyLoss(ignore_index=-100)
-                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-                lm_logits = lm_logits.to(hidden_states.dtype)
-                loss = loss.to(hidden_states.dtype)
+                # ChatGLM doesn't have lm_head split
+                enable_tp = shard_config.enable_tensor_parallelism
+                shard_config.enable_tensor_parallelism = False
+                loss = dist_cross_entropy(
+                    labels,
+                    lm_logits,
+                    shard_config,
+                    self.transformer.output_layer.out_features,
+                    lm_logits.dtype,
+                )
+                shard_config.enable_tensor_parallelism = enable_tp
 
            if not return_dict:
                output = (lm_logits,) + transformer_outputs[1:]
                return ((loss,) + output) if loss is not None else output
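The block above temporarily clears `enable_tensor_parallelism` because ChatGLM's output layer is not vocab-split, so the loss must be taken over the full logits. Purely as an illustration (not part of this PR), the same save/restore pattern could be expressed as a context manager:

    # Hypothetical helper mirroring the flag toggle in the hunk above.
    from contextlib import contextmanager

    @contextmanager
    def tp_disabled(shard_config):
        saved = shard_config.enable_tensor_parallelism
        shard_config.enable_tensor_parallelism = False
        try:
            yield
        finally:
            shard_config.enable_tensor_parallelism = saved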
@@ -379,6 +374,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
+        force_sp_output_gather: Optional[bool] = True,
    ):
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -456,22 +452,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
        )
-        if sp_mode in ["split_gather"]:
-            hidden_states = gather_forward_split_backward(
-                hidden_states,
-                dim=0,
-                process_group=shard_config.tensor_parallel_process_group,
-                fp8_communication=shard_config.fp8_communication,
-            )
-        elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(
-                hidden_states,
-                dim=0,
-                process_group=sp_group,
-                grad_scale=sp_size,
-                fp8_communication=shard_config.fp8_communication,
-            )
+        if shard_config.enable_sequence_parallelism:
+            if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
+                hidden_states = gather_sp_output(hidden_states, shard_config, sp_dim=0)
 
        if not return_dict:
            return tuple(
@@ -17,14 +17,13 @@ from transformers.models.cohere.modeling_cohere import (
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer._operation import (
-    all_to_all_comm,
-    gather_forward_split_backward,
-    split_forward_gather_backward,
-)
+from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import ColoAttention, dist_cross_entropy
+from ..layer._operation import gather_sp_output, is_share_sp_tp
 
-_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring"]
+_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
 
@@ -52,6 +51,7 @@ class CommandPipelineForwards:
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
+        force_sp_output_gather: bool = True,
    ):
        logger = logging.get_logger(__name__)
 
@@ -93,10 +93,16 @@ class CommandPipelineForwards:
            if not isinstance(past_key_values, StaticCache):
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_seen_tokens = past_key_values.get_seq_length()
+
+        # NOTE: For generating full positions ids
+        # (the states will be gathered along the seq dim before attention fwd).
+        if shard_config.sequence_parallelism_mode != "ring_attn" and not stage_manager.is_first_stage():
+            seq_length *= shard_config.sequence_parallel_size
+
        if cache_position is None:
            if isinstance(past_key_values, StaticCache):
                raise ValueError("cache_position is a required argument when using StaticCache.")
-            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)
+            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
 
        seq_length_with_past = seq_length + past_seen_tokens
 
@@ -136,7 +142,7 @@ class CommandPipelineForwards:
            )
            use_cache = False
 
-        if shard_config and shard_config.enable_sequence_parallelism:
+        if stage_manager.is_first_stage() and shard_config.enable_sequence_parallelism:
            if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
                hidden_states = split_forward_gather_backward(
                    hidden_states,
@@ -208,23 +214,10 @@ class CommandPipelineForwards:
 
        if stage_manager.is_last_stage():
            hidden_states = self.norm(hidden_states)
-            if shard_config and shard_config.enable_sequence_parallelism:
-                if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
-                    hidden_states = gather_forward_split_backward(
-                        hidden_states,
-                        dim=1,
-                        process_group=shard_config.tensor_parallel_process_group,
-                        fp8_communication=shard_config.fp8_communication,
-                    )
-                elif shard_config.sequence_parallelism_mode == "all_to_all":
-                    hidden_states = gather_forward_split_backward(
-                        hidden_states,
-                        dim=1,
-                        process_group=shard_config.sequence_parallel_process_group,
-                        grad_scale=shard_config.sequence_parallel_size,
-                        fp8_communication=shard_config.fp8_communication,
-                    )
+            sp_mode = shard_config.sequence_parallelism_mode
+            if shard_config.enable_sequence_parallelism:
+                if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
+                    hidden_states = gather_sp_output(hidden_states, shard_config)
 
        # add hidden states from the last decoder layer
        if output_hidden_states:
@@ -327,6 +320,7 @@ class CommandPipelineForwards:
            hidden_states=hidden_states,
            stage_index=stage_index,
            shard_config=shard_config,
+            force_sp_output_gather=False,
        )
        past_key_values = None
 
@@ -335,9 +329,10 @@ class CommandPipelineForwards:
        logits = self.lm_head(hidden_states)
        logits = logits * self.logit_scale
        logits = logits.float()
-        loss = dist_cross_entropy(
-            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
-        )
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
 
        if not return_dict:
            output = (logits,) + outputs[1:]
@@ -482,6 +477,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
+        force_sp_output_gather: bool = True,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
@@ -584,14 +580,10 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
 
        hidden_states = self.norm(hidden_states)
 
-        if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(
-                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
-            )
-        elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(
-                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
-            )
+        # Cases that don't support parallelizing cross entropy computation along sequence
+        if shard_config.enable_sequence_parallelism:
+            if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
+                hidden_states = gather_sp_output(hidden_states, shard_config)
 
        # add hidden states from the last decoder layer
        if output_hidden_states:
@@ -676,6 +668,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
+            force_sp_output_gather=False,
        )
 
        hidden_states = outputs[0]
@@ -683,12 +676,14 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
        logits = self.lm_head(hidden_states)
        logits = logits * self.logit_scale
        logits = logits.float()
 
+        loss = None
+        if labels is not None:
            loss = dist_cross_entropy(
                labels,
                logits,
                shard_config,
                self.lm_head.out_features,
-                self.config.vocab_size,
                self.model.dtype,
            )
 
@@ -21,8 +21,9 @@ from transformers.models.gpt2.modeling_gpt2 import (
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer import ColoAttention
-from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
+from colossalai.shardformer.layer import ColoAttention, RingAttention
+from colossalai.shardformer.layer._operation import gather_sp_output, split_forward_gather_backward
+from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import dist_cross_entropy
@@ -39,10 +40,16 @@ def _get_attention_mask(
    encoder_hidden_states: Optional[torch.Tensor],
    encoder_attention_mask: Optional[torch.FloatTensor],
 ) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]:
-    batch_size, seq_len = hidden_states.shape[:2]
+    # Received input is already split for non-first pipeline stages,
+    # but attn mask isn't
+    batch_size = hidden_states.size(0)
+    seq_len = attention_mask.size(-1)
+
+    sp_mode = shard_config.sequence_parallelism_mode
    # If a 2D or 3D attention mask is provided for the cross-attention
    # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
    if self.config.add_cross_attention and encoder_hidden_states is not None:
+        assert not sp_mode == "ring_attn", "Ring Attention only supports decoder-only."
        encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
        if shard_config.enable_flash_attention:
            encoder_attention_mask = ColoAttention.prepare_attn_kwargs(
@@ -62,6 +69,7 @@ def _get_attention_mask(
            encoder_attention_mask = {"attention_mask": None}
    else:
        encoder_attention_mask = None
+
    # GPT2Attention mask.
    past_key_values_length = 0
    if past_key_values is not None and past_key_values[0] is not None:
@@ -69,6 +77,7 @@ def _get_attention_mask(
    if shard_config.enable_flash_attention:
        if attention_mask is not None:
            attention_mask = attention_mask.view(batch_size, -1)
+
        attention_mask = ColoAttention.prepare_attn_kwargs(
            (batch_size, 1, seq_len, seq_len + past_key_values_length),
            hidden_states.dtype,
@@ -123,6 +132,7 @@ class GPT2PipelineForwards:
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
+        force_sp_gather: Optional[bool] = True,
    ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
        # Please refer to original code of transformers for more details.
@@ -146,16 +156,15 @@ class GPT2PipelineForwards:
            logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
            use_cache = False
 
-        if stage_manager.is_first_stage():
+        disable_pp = stage_manager is None
+        if disable_pp or stage_manager.is_first_stage():
            if input_ids is not None and inputs_embeds is not None:
                raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
            elif input_ids is not None:
                input_shape = input_ids.size()
                input_ids = input_ids.view(-1, input_shape[-1])
-                input_ids.shape[0]
            elif inputs_embeds is not None:
                input_shape = inputs_embeds.size()[:-1]
-                inputs_embeds.shape[0]
            else:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
 
@@ -176,7 +185,7 @@ class GPT2PipelineForwards:
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
 
-        if stage_manager.is_first_stage():
+        if disable_pp or stage_manager.is_first_stage():
            if position_ids is None:
                position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
                position_ids = position_ids.unsqueeze(0)
@@ -190,9 +199,7 @@ class GPT2PipelineForwards:
                hidden_states = hidden_states + token_type_embeds
            hidden_states = self.drop(hidden_states)
 
-        output_shape = input_shape + (hidden_states.size(-1),)
-
-        attention_mask, encoder_attention_mask = _get_attention_mask(
+        attn_kwargs, encoder_attention_mask = _get_attention_mask(
            self,
            shard_config,
            hidden_states,
@@ -215,23 +222,43 @@ class GPT2PipelineForwards:
 
        # split the input tensor along sequence dimension
        # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
-        if shard_config and shard_config.enable_sequence_parallelism:
-            if shard_config.sequence_parallelism_mode == "split_gather":
+        sp_mode = shard_config.sequence_parallelism_mode
+        sp_group = shard_config.sequence_parallel_process_group
+        if disable_pp or stage_manager.is_first_stage():
+            # Ring Attention's special zigzag batch processing
+            if sp_mode == "ring_attn":
+                assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
+                if not attention_mask.bool().all():
+                    hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
+                        attention_mask, sp_group, hidden_states, position_ids
+                    )
+                else:
+                    hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
+            # Other sp modes
+            else:
+                if sp_mode == "split_gather":
                    hidden_states = split_forward_gather_backward(
                        hidden_states,
                        dim=1,
                        process_group=shard_config.tensor_parallel_process_group,
-                        fp8_communication=shard_config.fp8_communication,
                    )
+        elif sp_mode == "ring_attn":
+            # Later stages already received split hidden states
+            _, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)
+            del attention_mask
 
        # Going through held blocks.
+        if disable_pp:
+            start_idx, end_idx = 0, len(self.h)
+        else:
            start_idx, end_idx = stage_index[0], stage_index[1]
 
        for i in range(start_idx, end_idx):
            block = self.h[i]
            torch.cuda.set_device(hidden_states.device)
            # Ensure that attention_mask is always on the same device as hidden_states
-            if torch.is_tensor(attention_mask):
-                attention_mask = attention_mask.to(hidden_states.device)
+            if torch.is_tensor(attn_kwargs):
+                attn_kwargs = attn_kwargs.to(hidden_states.device)
            if isinstance(head_mask, torch.Tensor):
                head_mask = head_mask.to(hidden_states.device)
            if output_hidden_states:
@@ -242,7 +269,7 @@ class GPT2PipelineForwards:
                    block.__call__,
                    hidden_states,
                    None,
-                    attention_mask,
+                    attn_kwargs,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
@@ -253,7 +280,7 @@ class GPT2PipelineForwards:
                outputs = block(
                    hidden_states,
                    layer_past=None,
-                    attention_mask=attention_mask,
+                    attention_mask=attn_kwargs,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
@@ -270,26 +297,25 @@ class GPT2PipelineForwards:
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
 
-        # When sequence parallelism done, gather the output tensor in forward and split it in backward
-        if shard_config and shard_config.enable_sequence_parallelism:
-            if shard_config.sequence_parallelism_mode == "split_gather":
-                hidden_states = gather_forward_split_backward(
-                    hidden_states,
-                    dim=1,
-                    process_group=shard_config.tensor_parallel_process_group,
-                    fp8_communication=shard_config.fp8_communication,
-                )
+        # When sequence parallelism is done, gather the output tensor in forward and split it in backward
+        gather_output = (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode)
+        if disable_pp or stage_manager.is_last_stage():
+            if gather_output:
+                hidden_states = gather_sp_output(hidden_states, shard_config)
 
-        if stage_manager.is_last_stage():
+        # gather_sp_output could've changed seq length.
+        input_shape = (*input_shape[:-1], hidden_states.size(-2))
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        if disable_pp or stage_manager.is_last_stage():
            hidden_states = self.ln_f(hidden_states)
 
        hidden_states = hidden_states.view(output_shape)
 
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
 
-        if stage_manager.is_last_stage():
+        if disable_pp or stage_manager.is_last_stage():
            if not return_dict:
                return tuple(
                    v
@@ -366,16 +392,28 @@ class GPT2PipelineForwards:
            hidden_states=hidden_states,
            stage_index=stage_index,
            shard_config=shard_config,
+            force_sp_gather=False,
        )
 
        # If not at the last stage, return hidden_states as in GPT2Model
-        if not stage_manager.is_last_stage():
+        disable_pp = stage_manager is None
+        if (not disable_pp) and (not stage_manager.is_last_stage()):
            return {"hidden_states": outputs["hidden_states"]}
 
        hidden_states = outputs[0]
        lm_logits = self.lm_head(hidden_states)
+        if shard_config.sequence_parallelism_mode == "ring_attn":
+            # Split labels in a zigzag fashion too
+            sp_group = shard_config.sequence_parallel_process_group
+            if not attention_mask.bool().all():
+                # [B, max_seqlen // sp_size]
+                labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
+            else:
+                labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
+
+        if labels is not None:
            loss = dist_cross_entropy(
-                labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
+                labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype
            )
 
        if not return_dict:
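For `ring_attn`, labels have to be re-laid-out exactly like the hidden states so that each rank's logits line up with its label shard. A condensed sketch of the branch above, assuming `attention_mask` is the [B, S] padding mask from the forward:

    # Padded batches reuse prepare_varlen_batch (is_label=True applies the
    # label-specific handling); fully dense batches fall back to zigzag splitting.
    if shard_config.sequence_parallelism_mode == "ring_attn":
        sp_group = shard_config.sequence_parallel_process_group
        if not attention_mask.bool().all():
            labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
        else:
            labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)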
@@ -770,7 +808,7 @@ class GPT2PipelineForwards:
    )
 
 
-def get_gpt2_flash_attention_forward():
+def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None):
    from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
 
    def forward(
@@ -817,6 +855,21 @@ def get_gpt2_flash_attention_forward():
        if self.scale_attn_by_inverse_layer_idx:
            scale /= float(self.layer_idx + 1)
        dropout_p = self.attn_dropout.p if self.training else 0.0
+
+        sp_mode = shard_config.sequence_parallelism_mode
+        sp_group = shard_config.sequence_parallel_process_group
+        if sp_mode == "ring_attn":
+            attn_output = RingAttention.attention(
+                query,
+                key,
+                value,
+                sp_group,
+                **attention_mask,
+                dropout_p=dropout_p,
+                scale=scale,
+                inner_ring_size=shard_config.inner_ring_size,
+            )
+        else:
            attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
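A condensed view of the dispatch added above: the flash-attention forward now routes to RingAttention when the plugin selects `ring_attn`, and keeps the existing ColoAttention path otherwise (all names as in the hunk; this is not additional behavior):

    if sp_mode == "ring_attn":
        attn_output = RingAttention.attention(
            query, key, value, sp_group,
            **attention_mask,
            dropout_p=dropout_p, scale=scale,
            inner_ring_size=shard_config.inner_ring_size,
        )
    else:
        attn_output = ColoAttention.attention(
            query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale
        )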
@ -828,466 +881,6 @@ def get_gpt2_flash_attention_forward():
|
||||||
return forward
|
return forward
|
||||||
|
|
||||||
|
|
||||||
def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|
||||||
def forward(
|
|
||||||
self: GPT2Model,
|
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
token_type_ids: Optional[torch.LongTensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
|
||||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
|
||||||
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(
                past_length,
                input_shape[-1] + past_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)

        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        attention_mask, encoder_attention_mask = _get_attention_mask(
            self,
            shard_config,
            hidden_states,
            past_key_values,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
        )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure layer_past is on same device as hidden_states (might not be correct)
                if layer_past is not None:
                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
                # Ensure that attention_mask is always on the same device as hidden_states
                if torch.is_tensor(attention_mask):
                    attention_mask = attention_mask.to(hidden_states.device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    presents,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

    return forward

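The checkpointed branch above wraps each GPT2 block in a `create_custom_forward` closure so that non-tensor flags (`use_cache`, `output_attentions`) are captured instead of being passed through `torch.utils.checkpoint.checkpoint`, which only forwards tensor arguments cleanly. A minimal, self-contained sketch of the same pattern (the `ToyBlock` module and shapes are invented for illustration, not part of this PR):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    """Stand-in for a transformer block that also takes a non-tensor flag."""

    def __init__(self, dim=64):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x, scale_output=False):
        out = self.ff(x)
        return out * 2 if scale_output else out


def create_custom_forward(module, scale_output):
    # Capture the non-tensor flag in the closure; checkpoint() re-runs the
    # wrapped function with only the tensor arguments during recomputation.
    def custom_forward(*inputs):
        return module(*inputs, scale_output)

    return custom_forward


block = ToyBlock()
x = torch.randn(2, 16, 64, requires_grad=True)
# Activations inside the block are recomputed in backward instead of stored.
y = checkpoint(create_custom_forward(block, True), x, use_reentrant=True)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 16, 64])
```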
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(
                past_length,
                input_shape[-1] + past_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)
        attention_mask, encoder_attention_mask = _get_attention_mask(
            self,
            shard_config,
            hidden_states,
            past_key_values,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
        )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger = logging.get_logger(__name__)
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None

        # split the input tensor along sequence dimension
        # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
        hidden_states = split_forward_gather_backward(
            hidden_states,
            dim=1,
            process_group=shard_config.sequence_parallel_process_group,
            fp8_communication=shard_config.fp8_communication,
        )

        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure layer_past is on same device as hidden_states (might not be correct)
                if layer_past is not None:
                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
                # Ensure that attention_mask is always on the same device as hidden_states
                if torch.is_tensor(attention_mask):
                    attention_mask = attention_mask.to(hidden_states.device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        # When sequence parallelism done, gather the output tensor in forward and split it in backward
        hidden_states = gather_forward_split_backward(
            hidden_states,
            dim=1,
            process_group=shard_config.sequence_parallel_process_group,
            fp8_communication=shard_config.fp8_communication,
        )

        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    presents,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

    return forward

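The sequence-parallel flow above boils down to: shard the hidden states along the sequence dimension, run the transformer blocks on the local shard, then gather the shards back before the final LayerNorm and LM head. A toy sketch of that data movement with plain `torch.distributed` collectives (assumes an already-initialized process group; it deliberately ignores the forward/backward asymmetry that `split_forward_gather_backward` / `gather_forward_split_backward` implement as autograd functions):

```python
import torch
import torch.distributed as dist


def split_along_seq(hidden_states: torch.Tensor, group=None) -> torch.Tensor:
    """Keep only this rank's chunk of the sequence dimension (dim=1)."""
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    return hidden_states.chunk(world_size, dim=1)[rank].contiguous()


def gather_along_seq(hidden_states: torch.Tensor, group=None) -> torch.Tensor:
    """All-gather the sequence chunks and stitch the full sequence back together."""
    world_size = dist.get_world_size(group)
    buffers = [torch.empty_like(hidden_states) for _ in range(world_size)]
    dist.all_gather(buffers, hidden_states, group=group)
    return torch.cat(buffers, dim=1)


# Conceptually, inside the model forward:
#   hidden_states = split_along_seq(hidden_states)   # [B, S, H] -> [B, S/sp, H]
#   hidden_states = run_blocks(hidden_states)        # attention restores full context internally
#   hidden_states = gather_along_seq(hidden_states)  # [B, S/sp, H] -> [B, S, H]
```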
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
    from transformers import GPT2LMHeadModel

    def forward(
        self: GPT2LMHeadModel,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        loss = dist_cross_entropy(
            labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
        )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    return forward

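`dist_cross_entropy` exists so the loss can be computed without materializing the full vocabulary (or the full sequence) of logits on a single rank. For the vocab-split case the core trick is standard: every rank holds a logits shard, the max and the exp-sum are all-reduced, and only the rank that owns the target column contributes the target logit. The following is a simplified, hypothetical sketch of that idea, not the ColossalAI implementation (it assumes an initialized process group and ignores label shifting and `ignore_index` handling):

```python
import torch
import torch.distributed as dist


def vocab_parallel_cross_entropy(local_logits, labels, vocab_start, group=None):
    # local_logits: [N, V_local] shard of the vocab dim; labels: [N] global token ids
    vocab_end = vocab_start + local_logits.size(-1)

    # 1) Global max for numerical stability
    local_max = local_logits.max(dim=-1).values
    dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=group)
    shifted = local_logits - local_max.unsqueeze(-1)

    # 2) Global log-sum-exp via an all-reduced sum of local exp sums
    sum_exp = shifted.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)

    # 3) Target logit: only the owning rank has it, the others contribute zero
    in_shard = (labels >= vocab_start) & (labels < vocab_end)
    local_ids = (labels - vocab_start).clamp(0, local_logits.size(-1) - 1)
    target = shifted.gather(1, local_ids.unsqueeze(1)).squeeze(1) * in_shard
    dist.all_reduce(target, op=dist.ReduceOp.SUM, group=group)

    # Cross entropy = logsumexp(logits) - logit[target]
    return (sum_exp.log() - target).mean()
```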
def get_jit_fused_gpt2_mlp_forward():
    from transformers.models.gpt2.modeling_gpt2 import GPT2MLP

@@ -25,7 +25,6 @@ from transformers.models.llama.modeling_llama import (
 from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer import AttnMaskType
 from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
 from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
 from colossalai.shardformer.shard import ShardConfig
@@ -58,10 +57,7 @@ class LlamaPipelineForwards:
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
-        # Split output only when computing cross entropy using llama_for_causal_lm_forward
-        # or get_lm_forward_with_dist_cross_entropy
-        # Default to True to avoid bug when calling classification forward from huggingface
-        force_sp_output_gather: bool = True,
+        force_sp_gather: bool = True,  # Set to false only when computing cross entropy
    ):
        logger = logging.get_logger(__name__)

@@ -78,8 +74,9 @@ class LlamaPipelineForwards:

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        disable_pp = stage_manager is None
        # retrieve input_ids and inputs_embeds
-        if stage_manager.is_first_stage():
+        if disable_pp or stage_manager.is_first_stage():
            if input_ids is not None and inputs_embeds is not None:
                raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
            elif input_ids is not None:
@@ -88,10 +85,10 @@ class LlamaPipelineForwards:
                batch_size, seq_length, _ = inputs_embeds.shape[:2]
            else:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
            if inputs_embeds is None:
                inputs_embeds = self.embed_tokens(input_ids)
            hidden_states = inputs_embeds
+            device = hidden_states.device
        else:
            input_shape = hidden_states.shape[:-1]
            batch_size, seq_length = input_shape
@@ -101,8 +98,8 @@ class LlamaPipelineForwards:
        sp_mode = shard_config.sequence_parallelism_mode
        sp_group = shard_config.sequence_parallel_process_group
        sp_size = shard_config.sequence_parallel_size
-        if sp_mode == "all_to_all" and not stage_manager.is_first_stage():
-            # For generating full positions ids, as the states will be gather along the seq dim in the attention layer later.
+        # Generating full positions ids for modes that gather sequence before attn
+        if stage_manager and (sp_mode != "ring_attn" and not stage_manager.is_first_stage()):
            seq_length *= sp_size

        past_seen_tokens = 0
@@ -117,7 +114,6 @@ class LlamaPipelineForwards:

        seq_length_with_past = seq_length + past_seen_tokens

-        # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
        if output_attentions:
            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
            output_attentions = False
@@ -130,14 +126,13 @@ class LlamaPipelineForwards:

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
-        # embed positions, for the first stage, hidden_states is the input embeddings,
-        # for the other stages, hidden_states is the output of the previous stage
-        if not stage_manager.is_first_stage() and sp_mode == "ring_attn":
+        no_split_input = disable_pp or not stage_manager.is_first_stage()
+        if no_split_input and sp_mode == "ring_attn":
            _, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)
        elif shard_config.enable_flash_attention:
-            # in this case, attention_mask is a dict rather than a tensor
            mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
-            attn_kwargs = ColoAttention.prepare_attn_kwargs(
+            attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
                mask_shape,
                hidden_states.dtype,
                hidden_states.device,
@@ -146,15 +141,15 @@ class LlamaPipelineForwards:
                invert=(sp_mode != "ring_attn"),
            )
        else:
-            attn_kwargs = self._update_causal_mask(attention_mask, hidden_states, cache_position)
+            attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position)

-        # Support SP + PP
-        # TODO: support padded casual cu_seqlens across stages
-        if stage_manager.is_first_stage():
+        # Support SP + PP. Later stages have already received the split input.
+        split_input = disable_pp or stage_manager.is_first_stage()
+        if split_input:
            # Ring Attention zigzag batch processing
            if sp_mode == "ring_attn":
                assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
-                if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
+                if not attention_mask.bool().all():
                    hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
                        attention_mask, sp_group, hidden_states, position_ids
                    )
@@ -181,8 +176,8 @@ class LlamaPipelineForwards:
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None
+        start_idx, end_idx = (0, len(self.layers)) if disable_pp else (stage_index[0], stage_index[1])

-        start_idx, end_idx = stage_index[0], stage_index[1]
        num_ckpt_layers = 0
        if self.gradient_checkpointing and self.training:
            num_ckpt_layers = end_idx - start_idx
@@ -228,18 +223,16 @@ class LlamaPipelineForwards:
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

-        if stage_manager.is_last_stage():
+        if disable_pp or stage_manager.is_last_stage():
            hidden_states = self.norm(hidden_states)
-            if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
-                hidden_states = gather_sp_output(
-                    hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
-                )
+            if (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode):  # noqa
+                hidden_states = gather_sp_output(hidden_states, shard_config)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        next_cache = next_decoder_cache if use_cache else None
-        if stage_manager.is_last_stage():
+        if disable_pp or stage_manager.is_last_stage():
            if not return_dict:
                return tuple(
                    v
@@ -257,7 +250,7 @@ class LlamaPipelineForwards:
                hidden_states=all_hidden_states,
                attentions=all_self_attns,
            )
-        # always return dict for imediate stage
+        # always return dict for intermediate stage
        return {"hidden_states": hidden_states}

    @staticmethod
@@ -323,7 +316,7 @@ class LlamaPipelineForwards:
            # Split labels in a zigzag fashion too
            sp_group = shard_config.sequence_parallel_process_group
            if attention_mask.bool().all():
-                labels = split_batch_zigzag(labels, sp_group, seq_dim=1)
+                labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
            else:
                # [B, max_seqlen // sp_size]
                labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
@@ -345,16 +338,17 @@ class LlamaPipelineForwards:
            hidden_states=hidden_states,
            stage_index=stage_index,
            shard_config=shard_config,
-            force_sp_output_gather=False,
+            force_sp_gather=False,
        )
        past_key_values = None

-        if stage_manager.is_last_stage():
+        disable_pp = stage_manager is None
+        if disable_pp or stage_manager.is_last_stage():
            hidden_states = outputs[0]
            logits = self.lm_head(hidden_states)
-            loss = dist_cross_entropy(
-                labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
-            )
+            loss = None
+            if labels is not None:
+                loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)

            if not return_dict:
                output = (logits,) + outputs[1:]
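The zigzag split used for `ring_attn` above balances the causal-attention workload across ranks: with `2 * sp_size` sequence chunks, rank `i` takes chunk `i` and chunk `2 * sp_size - 1 - i`, pairing cheap early positions with expensive late ones. A small illustrative sketch of that indexing on a label tensor (assumed shapes, not the `split_batch_zigzag` API itself):

```python
import torch


def zigzag_split(seq_tensor: torch.Tensor, sp_size: int, rank: int, seq_dim: int = 1) -> torch.Tensor:
    """Return the two chunks (i, 2*sp_size-1-i) owned by `rank`, concatenated."""
    chunks = seq_tensor.chunk(2 * sp_size, dim=seq_dim)
    return torch.cat([chunks[rank], chunks[2 * sp_size - 1 - rank]], dim=seq_dim)


labels = torch.arange(16).unsqueeze(0)          # [1, 16], sequence positions 0..15
print(zigzag_split(labels, sp_size=2, rank=0))  # positions 0-3 and 12-15
print(zigzag_split(labels, sp_size=2, rank=1))  # positions 4-7 and 8-11
```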
@@ -629,263 +623,3 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
        return attn_output, attn_weights, past_key_value

    return forward
-
-
-def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
-    logger = logging.get_logger(__name__)
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        # Split output only when computing cross entropy using llama_for_causal_lm_forward
-        # or get_lm_forward_with_dist_cross_entropy
-        # Default to True to avoid bug when calling classification forward from huggingface
-        force_sp_output_gather: bool = True,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # retrieve input_ids and inputs_embeds
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-            )
-
-        if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
-        past_seen_tokens = 0
-        seq_len = inputs_embeds.shape[1]
-        batch_size = inputs_embeds.shape[0]
-        if use_cache:  # kept for BC (cache positions)
-            if not isinstance(past_key_values, StaticCache):
-                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-                past_seen_tokens = past_key_values.get_seq_length()
-
-        if cache_position is None:
-            if isinstance(past_key_values, StaticCache):
-                raise ValueError("cache_position is a required argument when using StaticCache.")
-            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)
-        if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
-
-        if shard_config.enable_flash_attention:
-            mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len)
-            attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
-                mask_shape,
-                inputs_embeds.dtype,
-                inputs_embeds.device,
-                q_padding_mask=attention_mask,
-                is_causal=True,
-                invert=(sp_mode != "ring_attn"),
-            )
-
-        else:
-            attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
-
-        # Ring Attention zigzag batch processing
-        if sp_mode == "ring_attn":
-            assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
-            if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
-                inputs_embeds, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
-                    attention_mask, sp_group, inputs_embeds, position_ids
-                )
-            else:
-                inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group)
-                attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]}  # drop redundant tensors
-
-        elif is_share_sp_tp(sp_mode):
-            inputs_embeds = split_forward_gather_backward(
-                inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
-            )
-        elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(
-                inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
-            )
-        hidden_states = inputs_embeds
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        next_decoder_cache = None
-
-        for decoder_layer in self.layers:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    attn_kwargs,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                )
-
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=attn_kwargs,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                )
-
-            hidden_states = layer_outputs[0]
-
-            if use_cache:
-                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.norm(hidden_states)
-        # Cases that don't support parallelizing cross entropy computation along sequence
-        if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
-            hidden_states = gather_sp_output(
-                hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
-            )
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
-            )
-        if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
-
-    return forward
-
-
-def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
-    from transformers import LlamaForCausalLM
-
-    def forward(
-        self: LlamaForCausalLM,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
-        r"""
-        Args:
-            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-        Returns:
-
-        Example:
-
-        ```python
-        >>> from transformers import AutoTokenizer, LlamaForCausalLM
-
-        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
-        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
-
-        >>> prompt = "Hey, are you conscious? Can you talk to me?"
-        >>> inputs = tokenizer(prompt, return_tensors="pt")
-
-        >>> # Generate
-        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
-        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
-        ```"""
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output:
-            # Special processing: Split labels in a zigzag fashion too
-            sp_group = shard_config.sequence_parallel_process_group
-            if attention_mask.bool().all():
-                labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
-            else:
-                # [B, max_seq_len // sp_size]
-                labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
-
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            cache_position=cache_position,
-            force_sp_output_gather=False,
-        )
-
-        hidden_states = outputs[0]
-        if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-            logits = torch.cat(logits, dim=-1)
-        else:
-            logits = self.lm_head(hidden_states)
-        logits = logits.float()
-        loss = dist_cross_entropy(
-            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
-        )
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-    return forward
@@ -274,10 +274,9 @@ class MistralForwards:
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()
-        loss = dist_cross_entropy(
-            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
-        )
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)

        if not return_dict:
            output = (logits,) + outputs[1:]
@@ -687,10 +686,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()
-        loss = dist_cross_entropy(
-            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
-        )
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)

        if not return_dict:
            output = (logits,) + outputs[1:]
@@ -330,12 +330,13 @@ class OPTPipelineForwards:
        )
        if stage_manager.is_last_stage():
            logits = self.lm_head(outputs[0]).contiguous()
+            loss = None
+            if labels is not None:
                loss = dist_cross_entropy(
                    labels,
                    logits,
                    shard_config,
                    self.lm_head.out_features,
-                    self.config.vocab_size,
                    self.model.decoder.dtype,
                )

@@ -955,9 +956,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
        )

        logits = self.lm_head(outputs[0]).contiguous()
-        loss = dist_cross_entropy(
-            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.decoder.dtype
-        )
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.decoder.dtype)

        if not return_dict:
            output = (logits,) + outputs[1:]
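Across Llama, Mistral, OPT and Qwen2 the causal-LM heads now follow one uniform pattern: compute the logits, leave `loss = None` unless labels were actually passed, and call `dist_cross_entropy` without the explicit `vocab_size` argument (that parameter was dropped from the call signature in this PR). Schematically:

```python
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
    # New call shape: labels, logits, shard_config, lm_head out_features, dtype
    loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
```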
@@ -32,14 +32,12 @@ except ImportError:
 from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer._operation import (
-    all_to_all_comm,
-    gather_forward_split_backward,
-    split_forward_gather_backward,
-)
+from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig

 from ..layer import ColoAttention, dist_cross_entropy
+from ..layer._operation import gather_sp_output
+from ..layer.utils import is_share_sp_tp


 class Qwen2PipelineForwards:
@@ -64,6 +62,7 @@ class Qwen2PipelineForwards:
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
+        force_sp_output_gather: bool = True,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        logger = logging.get_logger(__name__)

@@ -115,6 +114,14 @@ class Qwen2PipelineForwards:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

+        # Support SP + PP
+        sp_size = shard_config.sequence_parallel_size
+        sp_group = shard_config.sequence_parallel_process_group
+        sp_mode = shard_config.sequence_parallelism_mode
+        # For generating full positions ids (the states will be gathered along the seq dim before attention fwd).
+        if sp_mode != "ring_attn" and not stage_manager.is_first_stage():
+            seq_length *= sp_size
+
        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
@@ -151,7 +158,6 @@ class Qwen2PipelineForwards:
        elif self._attn_implementation == "sdpa" and not output_attentions:
            # output_attentions=True can not be supported when using SDPA, and we fall back on
            # the manual implementation that requires a 4D causal mask in all cases.
-
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
@@ -160,7 +166,6 @@ class Qwen2PipelineForwards:
            )
        else:
            # 4d mask is passed through the layers
-
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
@@ -169,21 +174,20 @@ class Qwen2PipelineForwards:
                sliding_window=self.config.sliding_window,
            )

-        if shard_config and shard_config.enable_sequence_parallelism:
-            if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
+        if stage_manager.is_first_stage():
+            if shard_config.enable_sequence_parallelism:
+                if is_share_sp_tp(sp_mode):
                    hidden_states = split_forward_gather_backward(
                        hidden_states,
                        dim=1,
-                        process_group=shard_config.tensor_parallel_process_group,
-                        fp8_communication=shard_config.fp8_communication,
+                        process_group=sp_group,
                    )
-            elif shard_config.sequence_parallelism_mode == "all_to_all":
+                elif sp_mode == "all_to_all":
                    hidden_states = split_forward_gather_backward(
                        hidden_states,
                        dim=1,
-                        process_group=shard_config.sequence_parallel_process_group,
-                        grad_scale=1 / shard_config.sequence_parallel_size,
-                        fp8_communication=shard_config.fp8_communication,
+                        process_group=sp_group,
+                        grad_scale=1 / sp_size,
                    )

        # decoder layers
@@ -241,23 +245,10 @@ class Qwen2PipelineForwards:

        if stage_manager.is_last_stage():
            hidden_states = self.norm(hidden_states)
+            if shard_config.enable_sequence_parallelism:
+                if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
+                    hidden_states = gather_sp_output(hidden_states, shard_config)

-        if shard_config and shard_config.enable_sequence_parallelism:
-            if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
-                hidden_states = gather_forward_split_backward(
-                    hidden_states,
-                    dim=1,
-                    process_group=shard_config.tensor_parallel_process_group,
-                    fp8_communication=shard_config.fp8_communication,
-                )
-            elif shard_config.sequence_parallelism_mode == "all_to_all":
-                hidden_states = gather_forward_split_backward(
-                    hidden_states,
-                    dim=1,
-                    process_group=shard_config.sequence_parallel_process_group,
-                    grad_scale=shard_config.sequence_parallel_size,
-                    fp8_communication=shard_config.fp8_communication,
-                )
        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
@@ -351,15 +342,18 @@ class Qwen2PipelineForwards:
            hidden_states=hidden_states,
            stage_index=stage_index,
            shard_config=shard_config,
+            force_sp_output_gather=False,
        )
        past_key_values = None

        if stage_manager.is_last_stage():
            hidden_states = outputs[0]
+            if hidden_states.shape[1] == 2:
+                pass
            logits = self.lm_head(hidden_states)
-            loss = dist_cross_entropy(
-                labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
-            )
+            loss = None
+            if labels is not None:
+                loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)

        if not return_dict:
            output = (logits,) + outputs[1:]
@@ -541,7 +535,6 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
        # Because the input can be padded, the absolute sequence length depends on the max position id.
        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
-
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
@@ -635,6 +628,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
+        force_sp_output_gather: bool = True,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
@@ -750,14 +744,9 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No

        hidden_states = self.norm(hidden_states)

-        if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(
-                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
-            )
-        elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(
-                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
-            )
+        if shard_config.enable_sequence_parallelism:
+            if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
+                hidden_states = gather_sp_output(hidden_states, shard_config)

        # add hidden states from the last decoder layer
        if output_hidden_states:
@@ -834,14 +823,15 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
+            force_sp_output_gather=False,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()
-        loss = dist_cross_entropy(
-            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
-        )
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)

        if not return_dict:
            output = (logits,) + outputs[1:]
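The `split_forward_gather_backward` / `gather_sp_output` pair used above is the usual "conjugate collectives" trick: whatever is split in the forward pass must be gathered in the backward pass (and vice versa), optionally with a gradient rescale for the `all_to_all` mode. A bare-bones illustration as a custom autograd function (assumed process-group handling; a sketch of the idea, not the ColossalAI operator):

```python
import torch
import torch.distributed as dist


class SplitForwardGatherBackward(torch.autograd.Function):
    """Forward: keep the local sequence chunk. Backward: all-gather the gradients."""

    @staticmethod
    def forward(ctx, x, dim, group, grad_scale=1.0):
        ctx.dim, ctx.group, ctx.grad_scale = dim, group, grad_scale
        rank, world = dist.get_rank(group), dist.get_world_size(group)
        return x.chunk(world, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world = dist.get_world_size(ctx.group)
        parts = [torch.empty_like(grad_output) for _ in range(world)]
        dist.all_gather(parts, grad_output.contiguous(), group=ctx.group)
        grad = torch.cat(parts, dim=ctx.dim) * ctx.grad_scale
        # One gradient slot per forward input: x, dim, group, grad_scale
        return grad, None, None, None
```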
@@ -64,7 +64,7 @@ class ChatGLMPolicy(Policy):

        if sp_mode == "ring":
            warnings.warn(
-                f"For ChatGLM2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
+                f"For ChatGLM2, sequence parallelism doesn't support mode {sp_mode} yet, will set to be split_gather"
            )
            sp_mode = "split_gather"
        overlap = self.shard_config.enable_sequence_overlap
@ -6,14 +6,7 @@ from torch import Tensor, nn
|
||||||
|
|
||||||
import colossalai.shardformer.layer as col_nn
|
import colossalai.shardformer.layer as col_nn
|
||||||
|
|
||||||
from ..modeling.gpt2 import (
|
from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, get_jit_fused_gpt2_mlp_forward
|
||||||
GPT2PipelineForwards,
|
|
||||||
get_gpt2_flash_attention_forward,
|
|
||||||
get_gpt_model_forward_for_flash_attn,
|
|
||||||
get_jit_fused_gpt2_mlp_forward,
|
|
||||||
get_lm_forward_with_dist_cross_entropy,
|
|
||||||
gpt2_sequence_parallel_forward_fn,
|
|
||||||
)
|
|
||||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@@ -71,18 +64,10 @@ class GPT2Policy(Policy):
             warnings.warn(
                 f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
             )
-            sp_mode = "split_gather"
+            self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather"
         overlap = self.shard_config.enable_sequence_overlap
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
         use_flash_attention = self.shard_config.enable_flash_attention
-        # todo: currently sp cannot be used with flashattention
-        if sp_mode in ["split_gather", "ring", "all_to_all"]:
-            if use_flash_attention:
-                warnings.warn(
-                    f"Sequence parallelism mode {sp_mode} cannot be used with FlashAttention, will disable FlashAttention automatically."
-                )
-                self.shard_config.enable_flash_attention = False
-                use_flash_attention = False
         if self.shard_config.enable_tensor_parallelism:
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@@ -211,18 +196,16 @@ class GPT2Policy(Policy):
         if use_flash_attention:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_gpt2_flash_attention_forward(),
+                    "forward": get_gpt2_flash_attention_forward(shard_config=self.shard_config),
                 },
                 policy=policy,
                 target_key=attn_cls,
             )
-            if not self.shard_config.pipeline_stage_manager:
-                policy[GPT2Model].method_replacement = {
-                    "forward": get_gpt_model_forward_for_flash_attn(self.shard_config)
-                }

-        if sp_mode is not None:
-            policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
+        if not self.shard_config.pipeline_stage_manager and self.shard_config.enable_sequence_parallelism:
+            policy[GPT2Model].method_replacement = {
+                "forward": partial(GPT2PipelineForwards.gpt2_model_forward, shard_config=self.shard_config)
+            }

         return policy
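The policy hunk above replaces the dedicated flash-attention and sequence-parallel forwards with a single pipeline forward bound to shard_config via functools.partial. A hedged sketch of that replacement pattern (the helper below is illustrative, not upstream code):

# Sketch of the method-replacement pattern used above (illustrative only).
from functools import partial


def install_sp_forward(policy, shard_config, model_cls, model_forward):
    # Replace the module forward only when there is no pipeline schedule and
    # sequence parallelism is enabled, mirroring the condition in this hunk.
    if not shard_config.pipeline_stage_manager and shard_config.enable_sequence_parallelism:
        policy[model_cls].method_replacement = {
            "forward": partial(model_forward, shard_config=shard_config)
        }
    return policy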
@@ -328,40 +311,39 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
         from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

         module_policy = super().module_policy()
+        module_policy[GPT2LMHeadModel] = ModulePolicyDescription()
         if self.shard_config.enable_tensor_parallelism:
-            addon_module = {
-                GPT2LMHeadModel: ModulePolicyDescription(
-                    sub_module_replacement=[
-                        SubModuleReplacementDescription(
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
                     suffix="lm_head",
                     target_module=col_nn.VocabParallelLMHead1D,
                     kwargs={
                         "gather_output": False,
                         "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
-                        "fp8_communication": self.shard_config.fp8_communication,
                     },
+                ),
+                policy=module_policy,
+                target_key=GPT2LMHeadModel,
             )
-                    ],
-                )
-            }
-            if self.shard_config.parallel_output:
-                addon_module[GPT2LMHeadModel].method_replacement = {
-                    "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
-                }
         else:
-            addon_module = {
-                GPT2LMHeadModel: ModulePolicyDescription(
-                    sub_module_replacement=[
-                        SubModuleReplacementDescription(
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
                     suffix="lm_head",
                     target_module=col_nn.PaddingLMHead,
                     kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                ),
+                policy=module_policy,
+                target_key=GPT2LMHeadModel,
             )
-                    ]
+        if self.shard_config.parallel_output:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config)
+                },
+                policy=module_policy,
+                target_key=GPT2LMHeadModel,
             )
-            }
-        module_policy.update(addon_module)

         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
@@ -16,12 +16,7 @@ from colossalai.shardformer.layer import (
     VocabParallelLMHead1D,
 )

-from ..modeling.llama import (
-    LlamaPipelineForwards,
-    get_llama_flash_attention_forward,
-    get_llama_flash_attention_model_forward,
-    get_lm_forward_with_dist_cross_entropy,
-)
+from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"]
@@ -99,11 +94,9 @@ class LlamaPolicy(Policy):
         if self.pipeline_stage_manager is None:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_flash_attention_model_forward(
-                        self.shard_config,
-                        sp_mode=sp_mode,
-                        sp_size=sp_size,
-                        sp_group=sp_group,
+                    "forward": partial(
+                        LlamaPipelineForwards.llama_model_forward,
+                        shard_config=self.shard_config,
                     ),
                 },
                 policy=policy,
@@ -351,7 +344,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         elif self.shard_config.enable_tensor_parallelism or self.shard_config.enable_sequence_parallelism:
             # Compute loss distributedly along the sequence dimension
             new_item[LlamaForCausalLM].method_replacement = {
-                "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+                # "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+                "forward": partial(LlamaPipelineForwards.llama_for_causal_lm_forward, shard_config=self.shard_config)
             }
         return policy
@@ -28,7 +28,7 @@ MODEL_CONFIGS = {
     "118M": GPT2Config(activation_function="gelu"),
     "338M": GPT2Config(n_embd=1024, n_head=16, n_layer=24, activation_function="gelu"),
     "738M": GPT2Config(n_embd=1280, n_head=20, n_layer=36, activation_function="gelu"),
-    "6.21B": GPT2Config(n_embd=4096, n_head=32, n_layer=32, n_positions=4096, activation_function="gelu"),
+    "6.21B": GPT2Config(n_embd=4096, n_head=32, n_layer=32, n_positions=32768, activation_function="gelu"),
 }
@@ -60,6 +60,8 @@ def main():
     parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
     parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
+    parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size")
+    parser.add_argument("--sp_mode", type=str, default="ring_attn", help="Sequence parallel mode")
     parser.add_argument("--mbs", type=int, default=1)
     parser.add_argument("--zero", type=int, default=0)
     parser.add_argument("--pp_style", type=str, default="1f1b")
@@ -129,6 +131,9 @@ def main():
             tp_size=args.tp,
             pp_size=args.pp,
             pp_style=args.pp_style,
+            sp_size=args.sp,
+            sequence_parallelism_mode=args.sp_mode,
+            enable_sequence_parallelism=True,
             zero_stage=args.zero,
             num_model_chunks=args.num_model_chunks,
             enable_all_optimization=True,
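With the new --sp and --sp_mode flags, the benchmark can exercise ring attention end to end. A hedged sketch of how those arguments might reach HybridParallelPlugin (keyword names follow the hunks above; the surrounding defaults are illustrative):

# Hedged sketch: wiring the benchmark's --sp / --sp_mode flags into the plugin.
from colossalai.booster.plugin import HybridParallelPlugin


def build_plugin(args):
    return HybridParallelPlugin(
        tp_size=args.tp,
        pp_size=args.pp,
        sp_size=args.sp,                         # e.g. 2 sequence-parallel ranks
        sequence_parallelism_mode=args.sp_mode,  # "ring_attn", "all_to_all" or "split_gather"
        enable_sequence_parallelism=True,
        zero_stage=args.zero,
        precision="bf16",                        # illustrative default
    )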
@@ -214,6 +219,8 @@ def main():
             performance_evaluator.on_step_start(step)
             outputs = model(**batch)
             loss = outputs[0]
+            del outputs
+
             booster.backward(loss, optimizer)
             optimizer.step()
             optimizer.zero_grad()
@@ -6,7 +6,6 @@ import torch.distributed as dist
 from torch import Tensor
 from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

-from colossalai.accelerator import get_accelerator
 from colossalai.cluster import DistCoordinator
@@ -22,8 +21,11 @@ def divide(x: float, y: float) -> float:
 def all_reduce_mean(x: float, world_size: int) -> float:
     if world_size == 1:
         return x
-    tensor = torch.tensor([x], device=get_accelerator().get_current_device())
-    dist.all_reduce(tensor)
+    # Use CPU tensor to avoid OOM/weird NCCl error
+    gloo_group = dist.new_group(backend="gloo")
+    tensor = torch.tensor([x], device="cpu")
+    dist.all_reduce(tensor, group=gloo_group)
     tensor = tensor / world_size
     return tensor.item()
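The metric reduction now uses a CPU tensor on a gloo group rather than an accelerator tensor on the default NCCL group. The sketch below shows the same idea with the gloo group cached across calls; it is an illustration, not the benchmark's code.

# Illustrative variant of the hunk above: cache the gloo group instead of
# creating a new process group on every call.
import torch
import torch.distributed as dist

_GLOO_GROUP = None


def all_reduce_mean_cpu(x: float, world_size: int) -> float:
    global _GLOO_GROUP
    if world_size == 1:
        return x
    if _GLOO_GROUP is None:
        _GLOO_GROUP = dist.new_group(backend="gloo")
    tensor = torch.tensor([x], device="cpu")
    dist.all_reduce(tensor, group=_GLOO_GROUP)
    return (tensor / world_size).item()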
@@ -27,7 +27,16 @@ def data_gen_for_lm():
     # LM data gen
     # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
     data = data_gen()
-    data["labels"] = data["input_ids"].clone()
+    # Test padded sequence for Ring Attention
+    padding = torch.zeros(1, data["input_ids"].shape[1] // 2, dtype=torch.long)
+    data["input_ids"] = torch.cat([data["input_ids"], padding], dim=1)
+    data["attention_mask"] = torch.cat([data["attention_mask"], padding], dim=1)
+
+    ignore_idx = -100
+    labels = data["input_ids"].clone()
+    labels[~data["attention_mask"].bool()] = ignore_idx
+    data["labels"] = labels
     return data
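The padded positions are labeled -100 because that is the default ignore_index of PyTorch's cross entropy, so the padding added for the Ring Attention test contributes nothing to the loss. A minimal self-contained check (illustrative, not part of the test suite):

# Minimal demonstration that -100 labels are ignored by the loss (illustrative).
import torch
import torch.nn.functional as F

vocab, seq = 32, 8
logits = torch.randn(1, seq, vocab)
labels = torch.randint(0, vocab, (1, seq))
labels[:, seq // 2:] = -100  # mask the padded second half, as in the hunk above

loss = F.cross_entropy(logits.view(-1, vocab), labels.view(-1), ignore_index=-100)
print(loss)  # only the unmasked first half contributes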
@@ -157,7 +157,6 @@ def build_model_from_hybrid_plugin(
     sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)

     criterion = loss_fn
-
     plugin = pluggin_cls(**test_config)
     booster = Booster(plugin=plugin)
@@ -323,7 +322,6 @@ def check_output_hidden_state(
         sp_size = shard_config.sequence_parallel_size
         if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size:
             org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)]
-
     assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol)
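Under sequence parallelism the sharded model returns only its local sequence chunk, which is why the comparison helper now takes shard_config and chunks the unsharded reference before assert_close. A sketch of that comparison, written as a standalone helper for illustration:

# Sketch of the sequence-parallel comparison above (illustrative helper).
import torch.distributed as dist
from torch.testing import assert_close


def compare_hidden_states(org, sharded, shard_config, seq_dim=1, atol=5e-3, rtol=5e-3):
    sp_group = shard_config.sequence_parallel_process_group
    sp_size = shard_config.sequence_parallel_size
    if org.shape[seq_dim] == sharded.shape[seq_dim] * sp_size:
        # Keep only the chunk of the full-sequence reference held by this SP rank.
        org = org.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)]
    assert_close(org.float(), sharded.float(), atol=atol, rtol=rtol)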
@@ -136,6 +136,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {  # Ulysess + Flash attention
             "tp_size": 1,
             "pp_size": 2,

@@ -149,19 +161,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        {
-            "tp_size": 2,
-            "pp_size": 2,
-            "sp_size": 2,
-            "num_microbatches": 2,
-            "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "split_gather",
-            "enable_flash_attention": True,
-            "use_lazy_init": True,
-            "zero_stage": 1,
-            "precision": "fp16",
-            "initial_scale": 1,
-        },
         {
             "tp_size": 1,
             "pp_size": 1,

@@ -174,17 +173,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        {
-            "tp_size": 4,
-            "pp_size": 1,
-            "num_microbatches": 1,
-            "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "ring",
-            "enable_flash_attention": False,
-            "use_lazy_init": True,
-            "precision": "fp32",
-            "initial_scale": 1,
-        },
         {
             "tp_size": 4,
             "pp_size": 1,

@@ -248,7 +236,11 @@ def run_chatglm_test(test_config):
         loss_fn,
         _,
     ) in sub_model_zoo.items():
+        try:
             check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Test config failed for model {name}: {test_config}")
+            raise e

     clear_layout_converter()
     torch.cuda.empty_cache()

@@ -125,7 +125,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         atol, rtol = 5e-3, 5e-3

     if org_model.__class__.__name__ == "CohereModel":
-        check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+        check_output_hidden_state(
+            org_output,
+            sharded_output,
+            stage_manager,
+            atol=atol,
+            rtol=rtol,
+            shard_config=booster.plugin.shard_config,
+        )

     check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

@@ -274,7 +281,11 @@ def run_command_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm")

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        try:
             check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed test config: {test_config}")
+            raise e

     clear_layout_converter()
     Randomizer.reset_index()

@@ -100,7 +100,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         atol, rtol = 5e-3, 5e-3

     if org_model.__class__.__name__ == "GPT2Model":
-        check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+        check_output_hidden_state(
+            org_output,
+            sharded_output,
+            stage_manager,
+            atol=atol,
+            rtol=rtol,
+            shard_config=booster.plugin.shard_config,
+        )

     check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

@@ -132,14 +139,27 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     "test_config",
     [
         {
-            "tp_size": 4,
-            "pp_size": 1,
-            "num_microbatches": 1,
+            "sp_size": 2,
+            "tp_size": 1,
+            "pp_size": 2,
             "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "ring",
-            "enable_flash_attention": False,
+            "sequence_parallelism_mode": "ring_attn",
+            "num_microbatches": 2,
+            "enable_all_optimization": True,
             "use_lazy_init": True,
-            "precision": "fp32",
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "sp_size": 2,
+            "tp_size": 2,
+            "pp_size": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring_attn",
+            "num_microbatches": 1,
+            "enable_all_optimization": True,
+            "use_lazy_init": True,
+            "precision": "fp16",
             "initial_scale": 1,
         },
         {

@@ -148,7 +168,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "split_gather",
-            "enable_flash_attention": False,
+            "enable_flash_attention": True,
             "use_lazy_init": True,
             "precision": "fp16",
             "initial_scale": 1,

@@ -156,7 +176,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {
             "tp_size": 2,
             "pp_size": 2,
-            "num_microbatches": 4,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 2,
             "enable_all_optimization": True,
             "use_lazy_init": True,
             "precision": "fp16",

@@ -185,7 +216,16 @@ def run_gpt2_test(test_config):
         loss_fn,
         _,
     ) in sub_model_zoo.items():
+
+        if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and name != "transformers_gpt_lm":
+            # Only wrote zigzag splitting for cross entropy loss
+            continue
+
+        try:
             check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config} for model {name}")
+            raise (e)

     clear_layout_converter()
     torch.cuda.empty_cache()

@@ -226,7 +266,11 @@ def run_gpt2_3d_test(test_config):
         loss_fn,
         _,
     ) in sub_model_zoo.items():
+        try:
             check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config} for model {name}")
+            raise (e)

     clear_layout_converter()
     torch.cuda.empty_cache()

@@ -165,7 +165,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "zero_stage": 0,
             "precision": "fp16",
             "initial_scale": 1,
-            "inner_ring_size": 2,
         },
         # Ring Attention + PP
         {

@@ -215,18 +214,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "sequence_parallelism_mode": "all_to_all",
             "enable_all_optimization": True,
             "use_lazy_init": True,
-            "zero_stage": 0,
-            "precision": "fp16",
-            "initial_scale": 1,
-        },
-        {
-            "tp_size": 4,
-            "pp_size": 1,
-            "num_microbatches": 1,
-            "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "split_gather",
-            "enable_flash_attention": True,
-            "use_lazy_init": True,
+            "zero_stage": 1,
             "precision": "fp16",
             "initial_scale": 1,
         },

@@ -294,6 +282,7 @@ def run_llama_test(test_config):
         except Exception as e:
             print(f"Failed config: {test_config}, model name: {name}")
             raise e

     clear_layout_converter()
     Randomizer.reset_index()
     torch.cuda.empty_cache()

@@ -94,6 +94,32 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {  # Ulysess + Flash attention
+            "tp_size": 1,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {
             "tp_size": 2,
             "pp_size": 2,

@@ -135,32 +161,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        {  # Ulysess + Flash attention
-            "tp_size": 1,
-            "pp_size": 2,
-            "sp_size": 2,
-            "num_microbatches": 2,
-            "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "all_to_all",
-            "enable_flash_attention": True,
-            "use_lazy_init": True,
-            "zero_stage": 1,
-            "precision": "fp16",
-            "initial_scale": 1,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 2,
-            "sp_size": 2,
-            "num_microbatches": 2,
-            "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "split_gather",
-            "enable_flash_attention": True,
-            "use_lazy_init": True,
-            "zero_stage": 1,
-            "precision": "fp16",
-            "initial_scale": 1,
-        },
         {
             "tp_size": 2,
             "pp_size": 2,