mirror of https://github.com/hpcaitech/ColossalAI
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into rlhf_SimPO
commit 33f15203d3
@@ -132,7 +132,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
        if use_bias:
            bias.view(bias.shape)

        total_input = input
        total_input = input.contiguous()
        grad_input = grad_output.matmul(weight)
        grad_output = grad_output.contiguous()
        # Convert the tensor shapes to 2D for execution compatibility
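The `.contiguous()` calls added above matter because the 2D reshape that follows relies on `view`, which needs a contiguous memory layout. A standalone sketch (plain PyTorch, illustrative shapes only, not the repository code):

```python
import torch

# Gradients coming out of a transposed projection are often non-contiguous.
grad_output = torch.randn(4, 8).t()        # shape (8, 4), stride (1, 8)
assert not grad_output.is_contiguous()

try:
    grad_output.view(-1)                   # flattening a transposed tensor fails
except RuntimeError:
    pass

grad_output = grad_output.contiguous()     # materialise a contiguous copy first
flat = grad_output.view(-1)                # now the reshape to 2D/1D is legal
assert flat.shape == (32,)
```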
@@ -11,7 +11,11 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.layer._operation import (
    all_to_all_comm,
    gather_forward_split_backward,
    split_forward_gather_backward,
)


def get_flash_core_attention_forward():
@@ -203,6 +207,13 @@ class ChatGLMPipelineForwards:
                    dim=0,
                    process_group=shard_config.tensor_parallel_process_group,
                )
            elif shard_config.sequence_parallelism_mode == "all_to_all":
                hidden_states = split_forward_gather_backward(
                    hidden_states,
                    dim=0,
                    process_group=shard_config.sequence_parallel_process_group,
                    grad_scale=1 / shard_config.sequence_parallel_size,
                )
        for idx in range(start_idx, end_idx):
            layer = self.encoder._get_layer(idx)
            if output_hidden_states:
@@ -235,6 +246,13 @@ class ChatGLMPipelineForwards:
                    dim=0,
                    process_group=shard_config.tensor_parallel_process_group,
                )
            elif shard_config.sequence_parallelism_mode == "all_to_all":
                hidden_states = gather_forward_split_backward(
                    hidden_states,
                    dim=0,
                    process_group=shard_config.sequence_parallel_process_group,
                    grad_scale=shard_config.sequence_parallel_size,
                )
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        if stage_manager.is_last_stage():
@@ -329,7 +347,9 @@ class ChatGLMPipelineForwards:
        return transformer_outputs


def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode, sp_size, sp_group):
    logger = logging.get_logger(__name__)

    def forward(
        self,
        input_ids,
@@ -381,13 +401,27 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        if sp_mode in ["all_to_all"] and self.training:
            if use_cache:
                logger.warning_once(
                    f"`use_cache=True` is incompatible with sp mode `{sp_mode}`. Setting `use_cache=False`..."
                )
                use_cache = False
        # Run encoder.
        # [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size]
        inputs_embeds = split_forward_gather_backward(
            inputs_embeds,
            dim=0,
            process_group=shard_config.tensor_parallel_process_group,
        )
        if sp_mode in ["split_gather"]:
            inputs_embeds = split_forward_gather_backward(
                inputs_embeds,
                dim=0,
                process_group=sp_group,
            )
        elif sp_mode == "all_to_all":
            inputs_embeds = split_forward_gather_backward(
                inputs_embeds,
                dim=0,
                process_group=sp_group,
                grad_scale=1 / sp_size,
            )
        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
            inputs_embeds,
            full_attention_mask,
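`split_forward_gather_backward` shards a tensor along one dimension in the forward pass and reassembles the gradient in the backward pass. A single-process sketch of the idea, using a hypothetical `_SplitForwardGatherBackward` stand-in rather than ColossalAI's real implementation (which performs the gather with collective communication over the process group and honours `grad_scale`):

```python
import torch

class _SplitForwardGatherBackward(torch.autograd.Function):
    """Toy stand-in: forward keeps this rank's slice, backward rebuilds a full-size grad.

    The real helper all-gathers the gradient slices from every rank; here a single
    process fakes a world of size 2 purely to show the shapes involved.
    """

    @staticmethod
    def forward(ctx, x, rank, world_size, dim):
        ctx.dim, ctx.world_size = dim, world_size
        return x.chunk(world_size, dim=dim)[rank].clone()

    @staticmethod
    def backward(ctx, grad_slice):
        # Real code: dist.all_gather(...) then concat; rank/world_size/dim get no grads.
        full_grad = torch.cat([grad_slice] * ctx.world_size, dim=ctx.dim)
        return full_grad, None, None, None

x = torch.randn(8, 2, 16, requires_grad=True)        # [seq_len, batch, hidden]
y = _SplitForwardGatherBackward.apply(x, 0, 2, 0)     # this "rank" keeps seq_len / 2
assert y.shape == (4, 2, 16)
y.sum().backward()
assert x.grad.shape == x.shape                        # gradient comes back full-size
```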
@@ -397,11 +431,19 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
            output_hidden_states=output_hidden_states,
        )

        hidden_states = gather_forward_split_backward(
            hidden_states,
            dim=0,
            process_group=shard_config.tensor_parallel_process_group,
        )
        if sp_mode in ["split_gather"]:
            hidden_states = gather_forward_split_backward(
                hidden_states,
                dim=0,
                process_group=shard_config.tensor_parallel_process_group,
            )
        elif sp_mode == "all_to_all":
            hidden_states = gather_forward_split_backward(
                hidden_states,
                dim=0,
                process_group=sp_group,
                grad_scale=sp_size,
            )

        if not return_dict:
            return tuple(
@@ -423,3 +465,158 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
            )

    return forward


def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, sp_mode, sp_size, sp_group):
    from .chatglm2_6b.modeling_chatglm import apply_rotary_pos_emb, split_tensor_along_last_dim

    def forward(
        self,
        hidden_states,
        attention_mask,
        rotary_pos_emb,
        kv_cache=None,
        use_cache=True,
    ):
        if sp_mode is not None:
            assert sp_mode in ["all_to_all", "split_gather"], "Invalid sp_mode"
            assert (sp_size is not None) and (
                sp_group is not None
            ), "Must specify sp_size and sp_group for sequence parallel"

        mixed_x_layer = self.query_key_value(hidden_states)
        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            query_layer = query_layer.view(
                query_layer.size()[:-1]
                + (
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
            key_layer = key_layer.view(
                key_layer.size()[:-1]
                + (
                    self.num_multi_query_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
            value_layer = value_layer.view(
                value_layer.size()[:-1]
                + (
                    self.num_multi_query_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        # sp: all-to-all communication when introducing sequence parallel
        if sp_mode == "all_to_all":
            sq, bs, _, _ = value_layer.size()

            query_layer = query_layer.reshape(sq, bs, -1)
            key_layer = key_layer.reshape(sq, bs, -1)
            value_layer = value_layer.reshape(sq, bs, -1)

            query_layer = all_to_all_comm(query_layer, sp_group, gather_dim=0)
            key_layer = all_to_all_comm(key_layer, sp_group, gather_dim=0)
            value_layer = all_to_all_comm(value_layer, sp_group, gather_dim=0)

            query_layer = query_layer.view(
                sq * sp_size,
                bs,
                self.num_attention_heads_per_partition // sp_size,
                self.hidden_size_per_attention_head,
            ).contiguous()

            key_layer = key_layer.view(
                sq * sp_size,
                bs,
                self.num_attention_heads_per_partition // sp_size,
                self.hidden_size_per_attention_head,
            ).contiguous()

            value_layer = value_layer.view(
                sq * sp_size,
                bs,
                self.num_attention_heads_per_partition // sp_size,
                self.hidden_size_per_attention_head,
            ).contiguous()

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # adjust key and value for inference
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=0)
            value_layer = torch.cat((cache_v, value_layer), dim=0)
        if use_cache:
            kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None

        if self.multi_query_attention:
            key_layer = key_layer.unsqueeze(-2)
            key_layer = key_layer.expand(
                -1,
                -1,
                -1,
                self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
                -1,
            )
            key_layer = key_layer.contiguous().view(
                key_layer.size()[:2]
                + (
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
            value_layer = value_layer.unsqueeze(-2)
            value_layer = value_layer.expand(
                -1,
                -1,
                -1,
                self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
                -1,
            )
            value_layer = value_layer.contiguous().view(
                value_layer.size()[:2]
                + (
                    self.num_attention_heads_per_partition // sp_size,
                    self.hidden_size_per_attention_head,
                )
            )

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
        if sp_mode == "all_to_all":
            context_layer = all_to_all_comm(context_layer, sp_group, gather_dim=2, scatter_dim=0)

        # =================
        # Output. [sq, b, h]
        # =================
        output = self.dense(context_layer)

        return output, kv_cache

    return forward
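The all_to_all branch above swaps a sequence shard for a head shard: each rank enters core attention with the full sequence but only `num_heads // sp_size` heads, and the output exchange reverses that. A single-process shape sketch (plain slicing stands in for `all_to_all_comm`, which really uses distributed collectives; the sizes are made up):

```python
import torch

sp_size, sq_local, bs, heads, hd = 2, 8, 1, 4, 16

# Pretend we can see the whole (unsharded) query tensor at once.
q_global = torch.randn(sq_local * sp_size, bs, heads, hd)

# Before the exchange, rank 0 holds a slice of the sequence with all heads ...
q_rank0_in = q_global[:sq_local]                      # [sq/sp, bs, heads, hd]

# ... and after all_to_all_comm(..., gather_dim=0) it holds the full sequence with
# only its share of the heads, which is exactly what the
# .view(sq * sp_size, bs, heads // sp_size, hd) calls in the hunk express.
q_rank0_out = q_global[:, :, : heads // sp_size]      # [sq, bs, heads/sp, hd]

assert q_rank0_in.numel() == q_rank0_out.numel()      # the exchange moves data, never volume
assert q_rank0_out.shape == (sq_local * sp_size, bs, heads // sp_size, hd)
```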
@@ -134,6 +134,21 @@ class CommandPipelineForwards:
                )
                use_cache = False

        if shard_config and shard_config.enable_sequence_parallelism:
            if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
                hidden_states = split_forward_gather_backward(
                    hidden_states,
                    dim=1,
                    process_group=shard_config.tensor_parallel_process_group,
                )
            elif shard_config.sequence_parallelism_mode == "all_to_all":
                hidden_states = split_forward_gather_backward(
                    hidden_states,
                    dim=1,
                    process_group=shard_config.sequence_parallel_process_group,
                    grad_scale=1 / shard_config.sequence_parallel_size,
                )

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
@@ -190,6 +205,21 @@ class CommandPipelineForwards:
        if stage_manager.is_last_stage():
            hidden_states = self.norm(hidden_states)

            if shard_config and shard_config.enable_sequence_parallelism:
                if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
                    hidden_states = gather_forward_split_backward(
                        hidden_states,
                        dim=1,
                        process_group=shard_config.tensor_parallel_process_group,
                    )
                elif shard_config.sequence_parallelism_mode == "all_to_all":
                    hidden_states = gather_forward_split_backward(
                        hidden_states,
                        dim=1,
                        process_group=shard_config.sequence_parallel_process_group,
                        grad_scale=shard_config.sequence_parallel_size,
                    )

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
@@ -1,6 +1,7 @@
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
@@ -30,6 +31,11 @@ except ImportError:
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import (
    all_to_all_comm,
    gather_forward_split_backward,
    split_forward_gather_backward,
)
from colossalai.shardformer.shard import ShardConfig

from ..layer import ColoAttention, dist_cross_entropy
@@ -162,6 +168,21 @@ class Qwen2PipelineForwards:
                sliding_window=self.config.sliding_window,
            )

        if shard_config and shard_config.enable_sequence_parallelism:
            if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
                hidden_states = split_forward_gather_backward(
                    hidden_states,
                    dim=1,
                    process_group=shard_config.tensor_parallel_process_group,
                )
            elif shard_config.sequence_parallelism_mode == "all_to_all":
                hidden_states = split_forward_gather_backward(
                    hidden_states,
                    dim=1,
                    process_group=shard_config.sequence_parallel_process_group,
                    grad_scale=1 / shard_config.sequence_parallel_size,
                )

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
@@ -218,6 +239,20 @@ class Qwen2PipelineForwards:
        if stage_manager.is_last_stage():
            hidden_states = self.norm(hidden_states)

            if shard_config and shard_config.enable_sequence_parallelism:
                if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
                    hidden_states = gather_forward_split_backward(
                        hidden_states,
                        dim=1,
                        process_group=shard_config.tensor_parallel_process_group,
                    )
                elif shard_config.sequence_parallelism_mode == "all_to_all":
                    hidden_states = gather_forward_split_backward(
                        hidden_states,
                        dim=1,
                        process_group=shard_config.sequence_parallel_process_group,
                        grad_scale=shard_config.sequence_parallel_size,
                    )
        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
@@ -453,7 +488,7 @@ class Qwen2PipelineForwards:
        return {"hidden_states": hidden_states}


def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
    def forward(
        self: Qwen2Attention,
        hidden_states: torch.Tensor,
@@ -464,12 +499,28 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if sp_mode is not None:
            assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
            assert (sp_size is not None) and (
                sp_group is not None
            ), "Must specify sp_size and sp_group for sequence parallel"

        bsz, q_len, _ = hidden_states.size()
        # sp: modify sp_len when sequence parallel mode is ring
        if sp_mode in ["split_gather", "ring"]:
            q_len *= sp_size

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # sp: all-to-all communication when introducing sequence parallel
        if sp_mode == "all_to_all":
            query_states = all_to_all_comm(query_states, sp_group)
            key_states = all_to_all_comm(key_states, sp_group)
            value_states = all_to_all_comm(value_states, sp_group)
            bsz, q_len, _ = query_states.size()

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -522,10 +573,41 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
        attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
        if shard_config.enable_flash_attention:
            assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
            attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
        else:
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                    f" {attn_weights.size()}"
                )

            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                    )
                attn_weights = attn_weights + attention_mask

            # upcast attention to fp32
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights, value_states)

            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )
            attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        if sp_mode == "all_to_all":
            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
        else:
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
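The `else` branch introduced above is ordinary eager attention. A self-contained reference of the same computation with toy shapes (not the repository code; the causal mask is built by hand here):

```python
import math

import torch
import torch.nn.functional as F

bsz, num_heads, q_len, kv_len, head_dim = 2, 4, 8, 8, 16
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, kv_len, head_dim)
v = torch.randn(bsz, num_heads, kv_len, head_dim)
causal_mask = torch.triu(torch.full((q_len, kv_len), float("-inf")), diagonal=1)

# Same steps as the eager branch: scores, additive mask, fp32 softmax, weighted sum.
attn_weights = q @ k.transpose(2, 3) / math.sqrt(head_dim)
attn_weights = attn_weights + causal_mask                  # broadcast over (bsz, heads)
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = attn_weights @ v                             # [bsz, heads, q_len, head_dim]

# Back to [bsz, q_len, hidden] before the output projection, as in the hunk.
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, num_heads * head_dim)
assert attn_output.shape == (2, 8, 64)
```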
@@ -533,9 +615,8 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
    return forward


def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig):
def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
    logger = logging.get_logger(__name__)
    assert shard_config.enable_flash_attention, "Flash Attention is not enabled."

    def forward(
        self,
@@ -585,17 +666,26 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig):
        # embed positions
        hidden_states = inputs_embeds

        # in this case, attention_mask is a dict rather than a tensor
        mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
        attention_mask = ColoAttention.prepare_attn_kwargs(
            mask_shape,
            hidden_states.dtype,
            hidden_states.device,
            q_padding_mask=attention_mask,
            is_causal=True,
        )
        if shard_config.enable_flash_attention:
            # in this case, attention_mask is a dict rather than a tensor
            mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
            attention_mask = ColoAttention.prepare_attn_kwargs(
                mask_shape,
                hidden_states.dtype,
                hidden_states.device,
                q_padding_mask=attention_mask,
                is_causal=True,
            )
        else:
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
                sliding_window=self.config.sliding_window,
            )

        if self.gradient_checkpointing and self.training:
        if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@@ -607,6 +697,11 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig):
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        if sp_mode in ["ring", "split_gather"]:
            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
        elif sp_mode == "all_to_all":
            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
@@ -641,6 +736,11 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig):

        hidden_states = self.norm(hidden_states)

        if sp_mode == "ring" or sp_mode == "split_gather":
            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
        elif sp_mode == "all_to_all":
            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
@@ -9,6 +9,7 @@ import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards

from ..modeling.chatglm2 import (
    get_chatglm_sequence_parallel_attention_forward,
    get_chatglm_sequence_parallel_forward_fn,
    get_flash_core_attention_forward,
    get_jit_fused_glm_block_forward,
@@ -58,14 +59,29 @@ class ChatGLMPolicy(Policy):
            norm_cls = col_nn.LayerNorm

        sp_mode = self.shard_config.sequence_parallelism_mode or None
        assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for ChatGLM2"
        sp_size = self.shard_config.sequence_parallel_size or None
        sp_group = self.shard_config.sequence_parallel_process_group or None

        if sp_mode == "ring":
            warnings.warn(
                f"For ChatGLM2, sequence parallelism does not currently support mode {sp_mode}; setting it to split_gather"
            )
            sp_mode = "split_gather"
        overlap = self.shard_config.enable_sequence_overlap
        sp_partial_derived = sp_mode == "split_gather"
        sp_partial_derived = sp_mode in ["split_gather"]

        if sp_mode == "all_to_all":
            decoder_attribute_replacement = {
                "num_heads": self.model.config.num_attention_heads // sp_size,
                "hidden_size_per_partition": self.model.config.kv_channels
                * self.model.config.num_attention_heads
                // sp_size,
            }
            if getattr(self.model.config, "num_key_value_heads", False):
                decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
            policy["CoreAttention"] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
            )

        if self.shard_config.enable_tensor_parallelism:
            assert (
@@ -179,12 +195,26 @@ class ChatGLMPolicy(Policy):
            )

        # use sequence parallel
        if sp_mode == "split_gather":
        if self.shard_config.enable_sequence_parallelism:
            self.append_or_create_method_replacement(
                description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
                description={
                    "forward": get_chatglm_sequence_parallel_attention_forward(
                        self.shard_config, sp_mode, sp_size, sp_group
                    ),
                },
                policy=policy,
                target_key="ChatGLMModel",
                target_key="SelfAttention",
            )
            if self.pipeline_stage_manager is None:
                self.append_or_create_method_replacement(
                    description={
                        "forward": get_chatglm_sequence_parallel_forward_fn(
                            self.shard_config, sp_mode, sp_size, sp_group
                        )
                    },
                    policy=policy,
                    target_key="ChatGLMModel",
                )

        # use jit fused operator
        if self.shard_config.enable_jit_fused:
@@ -1,4 +1,3 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Union

@@ -66,13 +65,6 @@ class CommandPolicy(Policy):
        else:
            norm_cls = LayerNorm

        if self.pipeline_stage_manager is not None:
            self.shard_config.enable_sequence_parallelism = False
            self.shard_config.enable_sequence_overlap = False
            self.shard_config.sequence_parallelism_mode = None
            warnings.warn(
                f"For Command, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
            )
        sp_mode = self.shard_config.sequence_parallelism_mode or None
        sp_size = self.shard_config.sequence_parallel_size or None
        sp_group = self.shard_config.sequence_parallel_process_group or None
@@ -1,4 +1,3 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Union

@@ -82,9 +81,20 @@ class Qwen2Policy(Policy):
            embedding_cls = PaddingEmbedding
        norm_cls = FusedRMSNorm if self.shard_config.enable_fused_normalization else RMSNorm

        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn("Qwen2 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
        sp_mode = self.shard_config.sequence_parallelism_mode or None
        sp_size = self.shard_config.sequence_parallel_size or None
        sp_group = self.shard_config.sequence_parallel_process_group or None
        sp_partial_derived = sp_mode in ["split_gather", "ring"]
        if sp_mode == "all_to_all":
            decoder_attribute_replacement = {
                "num_heads": self.model.config.num_attention_heads // sp_size,
            }
            if getattr(self.model.config, "num_key_value_heads", False):
                decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size

            policy[attn_cls] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
            )

        if self.shard_config.enable_tensor_parallelism:
            assert (
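The attribute replacement above keeps per-rank bookkeeping consistent once Ulysses all-to-all shards the heads: each rank's attention module must report `num_heads // sp_size` (and likewise for the KV heads). A quick arithmetic check with illustrative numbers, not a specific Qwen2 config:

```python
num_attention_heads, num_key_value_heads, head_dim = 32, 8, 128
sp_size = 2

# What the policy writes into the attention module for each SP rank.
local_heads = num_attention_heads // sp_size         # 16
local_kv_heads = num_key_value_heads // sp_size       # 4

# Each rank produces [bsz, q_len, local_heads * head_dim] before the reverse
# all-to-all restores the full hidden size across the SP group.
assert local_heads * head_dim * sp_size == num_attention_heads * head_dim
assert num_attention_heads % sp_size == 0 and num_key_value_heads % sp_size == 0
```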
@@ -109,30 +119,37 @@ class Qwen2Policy(Policy):
                SubModuleReplacementDescription(
                    suffix="self_attn.q_proj",
                    target_module=Linear1D_Col,
                    kwargs=dict(seq_parallel_mode=sp_mode),
                ),
                SubModuleReplacementDescription(
                    suffix="self_attn.k_proj",
                    target_module=Linear1D_Col,
                    kwargs=dict(seq_parallel_mode=sp_mode),
                ),
                SubModuleReplacementDescription(
                    suffix="self_attn.v_proj",
                    target_module=Linear1D_Col,
                    kwargs=dict(seq_parallel_mode=sp_mode),
                ),
                SubModuleReplacementDescription(
                    suffix="self_attn.o_proj",
                    target_module=Linear1D_Row,
                    kwargs=dict(seq_parallel_mode=sp_mode),
                ),
                SubModuleReplacementDescription(
                    suffix="mlp.gate_proj",
                    target_module=Linear1D_Col,
                    kwargs=dict(seq_parallel_mode=sp_mode),
                ),
                SubModuleReplacementDescription(
                    suffix="mlp.up_proj",
                    target_module=Linear1D_Col,
                    kwargs=dict(seq_parallel_mode=sp_mode),
                ),
                SubModuleReplacementDescription(
                    suffix="mlp.down_proj",
                    target_module=Linear1D_Row,
                    kwargs=dict(seq_parallel_mode=sp_mode),
                ),
            ],
        )
@@ -154,10 +171,12 @@ class Qwen2Policy(Policy):
                SubModuleReplacementDescription(
                    suffix="input_layernorm",
                    target_module=norm_cls,
                    kwargs={"sp_partial_derived": sp_partial_derived},
                ),
                SubModuleReplacementDescription(
                    suffix="post_attention_layernorm",
                    target_module=norm_cls,
                    kwargs={"sp_partial_derived": sp_partial_derived},
                ),
            ],
            policy=policy,
@@ -168,16 +187,16 @@ class Qwen2Policy(Policy):
            description=SubModuleReplacementDescription(
                suffix="norm",
                target_module=norm_cls,
                kwargs={"sp_partial_derived": sp_partial_derived},
            ),
            policy=policy,
            target_key=Qwen2Model,
        )

        # use flash attention
        if self.shard_config.enable_flash_attention:
        if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
            self.append_or_create_method_replacement(
                description={
                    "forward": get_qwen2_flash_attention_forward(self.shard_config),
                    "forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                },
                policy=policy,
                target_key=attn_cls,
@@ -186,7 +205,9 @@ class Qwen2Policy(Policy):
            # replace qwen2 model forward method
            self.append_or_create_method_replacement(
                description={
                    "forward": get_qwen2_model_forward_for_flash_attn(self.shard_config),
                    "forward": get_qwen2_model_forward_for_flash_attn(
                        self.shard_config, sp_mode, sp_size, sp_group
                    ),
                },
                policy=policy,
                target_key=Qwen2Model,
@@ -136,6 +136,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
    "test_config",
    [
        {  # Ulysses + Flash attention
            "tp_size": 1,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "split_gather",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 1,
            "pp_size": 1,
            "sp_size": 2,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 4,
            "pp_size": 1,
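Each test config pins down a device mesh. For the all_to_all (Ulysses) entries the sequence-parallel group is a separate axis, so the minimum number of processes is `tp_size * pp_size * sp_size`; the split_gather/ring entries reuse the tensor-parallel group, as the hunks above show. A hypothetical helper sketch of that arithmetic (not part of the test suite):

```python
def ulysses_min_world_size(cfg: dict) -> int:
    """Minimum devices an all_to_all (Ulysses) test_config implies: tp * pp * sp.

    For "split_gather"/"ring" the sequence-parallel communication runs over the
    tensor-parallel group, so sp does not add an extra mesh axis there.
    """
    assert cfg.get("sequence_parallelism_mode") == "all_to_all"
    return cfg.get("tp_size", 1) * cfg.get("pp_size", 1) * cfg.get("sp_size", 1)

cfg = {"tp_size": 1, "pp_size": 2, "sp_size": 2,
       "sequence_parallelism_mode": "all_to_all"}
assert ulysses_min_world_size(cfg) == 4
```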
@@ -58,6 +58,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    # Check the grad when using ZeRO-1 and ZeRO-2
    if (
        booster.plugin.zero_stage in [1, 2]
        and booster.plugin.shard_config.pipeline_stage_manager is None
        and booster.plugin.shard_config.enable_sequence_parallelism
        and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
    ):
@@ -154,6 +155,45 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
    "test_config",
    [
        {  # Ulysses + Flash attention
            "tp_size": 1,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "split_gather",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "ring",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 1,
@@ -180,6 +180,68 @@ def run_qwen2_test(test_config):
            "zero_stage": 1,
            "initial_scale": 1,
        },
        {  # Ulysses + Flash attention
            "tp_size": 1,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "split_gather",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "ring",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 1,
            "pp_size": 1,
            "sp_size": 2,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 4,
            "pp_size": 1,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "split_gather",
            "enable_flash_attention": False,
            "use_lazy_init": True,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,