Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into rlhf_SimPO

pull/5850/head
YeAnbang 2024-07-10 10:39:34 +00:00
commit 33f15203d3
10 changed files with 560 additions and 50 deletions

View File

@@ -132,7 +132,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         if use_bias:
             bias.view(bias.shape)
 
-        total_input = input
+        total_input = input.contiguous()
         grad_input = grad_output.matmul(weight)
         grad_output = grad_output.contiguous()
         # Convert the tensor shapes to 2D for execution compatibility
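Note on the change above: total_input can be a non-contiguous view (for example the result of a transpose), while the later 2D reshape and fused matmul path expect dense memory, so the diff forces a contiguous copy. A standalone sketch of the behaviour being guarded against (illustrative shapes, not part of the commit):

import torch

x = torch.randn(4, 8)
t = x.t()                       # a transposed view: same storage, swapped strides
print(t.is_contiguous())        # False
# t.view(-1) would raise a RuntimeError because the strides are incompatible with view()
flat = t.contiguous().view(-1)  # .contiguous() materializes a dense copy first
print(flat.shape)               # torch.Size([32])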

View File

@@ -11,7 +11,11 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
 from colossalai.shardformer.layer import AttnMaskType, ColoAttention
-from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
+from colossalai.shardformer.layer._operation import (
+    all_to_all_comm,
+    gather_forward_split_backward,
+    split_forward_gather_backward,
+)
 
 
 def get_flash_core_attention_forward():
@@ -203,6 +207,13 @@ class ChatGLMPipelineForwards:
                     dim=0,
                     process_group=shard_config.tensor_parallel_process_group,
                 )
+            elif shard_config.sequence_parallelism_mode == "all_to_all":
+                hidden_states = split_forward_gather_backward(
+                    hidden_states,
+                    dim=0,
+                    process_group=shard_config.sequence_parallel_process_group,
+                    grad_scale=1 / shard_config.sequence_parallel_size,
+                )
         for idx in range(start_idx, end_idx):
             layer = self.encoder._get_layer(idx)
             if output_hidden_states:
@@ -235,6 +246,13 @@ class ChatGLMPipelineForwards:
                     dim=0,
                     process_group=shard_config.tensor_parallel_process_group,
                 )
+            elif shard_config.sequence_parallelism_mode == "all_to_all":
+                hidden_states = gather_forward_split_backward(
+                    hidden_states,
+                    dim=0,
+                    process_group=shard_config.sequence_parallel_process_group,
+                    grad_scale=shard_config.sequence_parallel_size,
+                )
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
         if stage_manager.is_last_stage():
@@ -329,7 +347,9 @@ class ChatGLMPipelineForwards:
         return transformer_outputs
 
 
-def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
+def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode, sp_size, sp_group):
+    logger = logging.get_logger(__name__)
+
     def forward(
         self,
         input_ids,
@@ -381,13 +401,27 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
             rotary_pos_emb = rotary_pos_emb[None, :seq_length]
         rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
 
+        if sp_mode in ["all_to_all"] and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with sp mode `{sp_mode}`. Setting `use_cache=False`..."
+                )
+                use_cache = False
         # Run encoder.
         # [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size]
-        inputs_embeds = split_forward_gather_backward(
-            inputs_embeds,
-            dim=0,
-            process_group=shard_config.tensor_parallel_process_group,
-        )
+        if sp_mode in ["split_gather"]:
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds,
+                dim=0,
+                process_group=sp_group,
+            )
+        elif sp_mode == "all_to_all":
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds,
+                dim=0,
+                process_group=sp_group,
+                grad_scale=1 / sp_size,
+            )
         hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
             inputs_embeds,
             full_attention_mask,
@@ -397,11 +431,19 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
             output_hidden_states=output_hidden_states,
         )
 
-        hidden_states = gather_forward_split_backward(
-            hidden_states,
-            dim=0,
-            process_group=shard_config.tensor_parallel_process_group,
-        )
+        if sp_mode in ["split_gather"]:
+            hidden_states = gather_forward_split_backward(
+                hidden_states,
+                dim=0,
+                process_group=shard_config.tensor_parallel_process_group,
+            )
+        elif sp_mode == "all_to_all":
+            hidden_states = gather_forward_split_backward(
+                hidden_states,
+                dim=0,
+                process_group=sp_group,
+                grad_scale=sp_size,
+            )
 
         if not return_dict:
             return tuple(
@@ -423,3 +465,158 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
         )
 
     return forward
+
+
+def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, sp_mode, sp_size, sp_group):
+    from .chatglm2_6b.modeling_chatglm import apply_rotary_pos_emb, split_tensor_along_last_dim
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask,
+        rotary_pos_emb,
+        kv_cache=None,
+        use_cache=True,
+    ):
+        if sp_mode is not None:
+            assert sp_mode in ["all_to_all", "split_gather"], "Invalid sp_mode"
+            assert (sp_size is not None) and (
+                sp_group is not None
+            ), "Must specify sp_size and sp_group for sequence parallel"
+
+        mixed_x_layer = self.query_key_value(hidden_states)
+        if self.multi_query_attention:
+            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
+                [
+                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
+                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+                ],
+                dim=-1,
+            )
+            query_layer = query_layer.view(
+                query_layer.size()[:-1]
+                + (
+                    self.num_attention_heads_per_partition,
+                    self.hidden_size_per_attention_head,
+                )
+            )
+            key_layer = key_layer.view(
+                key_layer.size()[:-1]
+                + (
+                    self.num_multi_query_groups_per_partition,
+                    self.hidden_size_per_attention_head,
+                )
+            )
+            value_layer = value_layer.view(
+                value_layer.size()[:-1]
+                + (
+                    self.num_multi_query_groups_per_partition,
+                    self.hidden_size_per_attention_head,
+                )
+            )
+        else:
+            new_tensor_shape = mixed_x_layer.size()[:-1] + (
+                self.num_attention_heads_per_partition,
+                3 * self.hidden_size_per_attention_head,
+            )
+            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
+            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
+
+        # sp: all-to-all communication when introducing sequence parallel
+        if sp_mode == "all_to_all":
+            sq, bs, _, _ = value_layer.size()
+
+            query_layer = query_layer.reshape(sq, bs, -1)
+            key_layer = key_layer.reshape(sq, bs, -1)
+            value_layer = value_layer.reshape(sq, bs, -1)
+
+            query_layer = all_to_all_comm(query_layer, sp_group, gather_dim=0)
+            key_layer = all_to_all_comm(key_layer, sp_group, gather_dim=0)
+            value_layer = all_to_all_comm(value_layer, sp_group, gather_dim=0)
+
+            query_layer = query_layer.view(
+                sq * sp_size,
+                bs,
+                self.num_attention_heads_per_partition // sp_size,
+                self.hidden_size_per_attention_head,
+            ).contiguous()
+            key_layer = key_layer.view(
+                sq * sp_size,
+                bs,
+                self.num_attention_heads_per_partition // sp_size,
+                self.hidden_size_per_attention_head,
+            ).contiguous()
+            value_layer = value_layer.view(
+                sq * sp_size,
+                bs,
+                self.num_attention_heads_per_partition // sp_size,
+                self.hidden_size_per_attention_head,
+            ).contiguous()
+
+        # apply relative positional encoding (rotary embedding)
+        if rotary_pos_emb is not None:
+            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
+            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
+
+        # adjust key and value for inference
+        if kv_cache is not None:
+            cache_k, cache_v = kv_cache
+            key_layer = torch.cat((cache_k, key_layer), dim=0)
+            value_layer = torch.cat((cache_v, value_layer), dim=0)
+        if use_cache:
+            kv_cache = (key_layer, value_layer)
+        else:
+            kv_cache = None
+
+        if self.multi_query_attention:
+            key_layer = key_layer.unsqueeze(-2)
+            key_layer = key_layer.expand(
+                -1,
+                -1,
+                -1,
+                self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
+                -1,
+            )
+            key_layer = key_layer.contiguous().view(
+                key_layer.size()[:2]
+                + (
+                    self.num_attention_heads_per_partition,
+                    self.hidden_size_per_attention_head,
+                )
+            )
+            value_layer = value_layer.unsqueeze(-2)
+            value_layer = value_layer.expand(
+                -1,
+                -1,
+                -1,
+                self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
+                -1,
+            )
+            value_layer = value_layer.contiguous().view(
+                value_layer.size()[:2]
+                + (
+                    self.num_attention_heads_per_partition // sp_size,
+                    self.hidden_size_per_attention_head,
+                )
+            )
+
+        # ==================================
+        # core attention computation
+        # ==================================
+        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
+        if sp_mode == "all_to_all":
+            context_layer = all_to_all_comm(context_layer, sp_group, gather_dim=2, scatter_dim=0)
+
+        # =================
+        # Output. [sq, b, h]
+        # =================
+        output = self.dense(context_layer)
+
+        return output, kv_cache
+
+    return forward
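The shape bookkeeping behind the all-to-all branch above can be checked without a distributed launch. The sketch below simulates the exchange on a single process purely to verify shapes; fake_all_to_all is a hypothetical stand-in for all_to_all_comm(x, sp_group, gather_dim=0), and all sizes are illustrative:

import torch

sp_size, sq, bs = 2, 8, 4          # 2 SP ranks, local sequence length 8, batch 4
num_heads, head_dim = 32, 128

def fake_all_to_all(x, gather_dim=0, scatter_dim=-1):
    # single-process stand-in: scatter one dimension across ranks, gather another
    chunks = x.chunk(sp_size, dim=scatter_dim)
    return torch.cat(chunks, dim=gather_dim)

q = torch.randn(sq, bs, num_heads, head_dim)  # [sq, b, np, hn] held by one SP rank
q = q.reshape(sq, bs, -1)                     # flatten heads before the exchange
q = fake_all_to_all(q, gather_dim=0)          # sequence dim grows, hidden dim shrinks
q = q.view(sq * sp_size, bs, num_heads // sp_size, head_dim)
print(q.shape)  # torch.Size([16, 4, 16, 128]): full sequence, 1/sp_size of the heads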

View File

@@ -134,6 +134,21 @@ class CommandPipelineForwards:
                 )
                 use_cache = False
 
+        if shard_config and shard_config.enable_sequence_parallelism:
+            if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
+                hidden_states = split_forward_gather_backward(
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.tensor_parallel_process_group,
+                )
+            elif shard_config.sequence_parallelism_mode == "all_to_all":
+                hidden_states = split_forward_gather_backward(
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.sequence_parallel_process_group,
+                    grad_scale=1 / shard_config.sequence_parallel_size,
+                )
+
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
@@ -190,6 +205,21 @@ class CommandPipelineForwards:
         if stage_manager.is_last_stage():
             hidden_states = self.norm(hidden_states)
+            if shard_config and shard_config.enable_sequence_parallelism:
+                if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
+                    hidden_states = gather_forward_split_backward(
+                        hidden_states,
+                        dim=1,
+                        process_group=shard_config.tensor_parallel_process_group,
+                    )
+                elif shard_config.sequence_parallelism_mode == "all_to_all":
+                    hidden_states = gather_forward_split_backward(
+                        hidden_states,
+                        dim=1,
+                        process_group=shard_config.sequence_parallel_process_group,
+                        grad_scale=shard_config.sequence_parallel_size,
+                    )
+
         # add hidden states from the last decoder layer
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
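On the grad_scale arguments used above: splitting the sequence in the forward pass scales the incoming gradient by 1 / sp_size, and the matching gather scales it by sp_size, so a split/gather pair leaves the end-to-end gradient magnitude unchanged. A toy single-process illustration of that pairing (the real ops live in colossalai.shardformer.layer._operation; ScaleGrad here is a hypothetical stand-in):

import torch

class ScaleGrad(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by `scale` in the
    backward pass, which is the role grad_scale plays inside the split/gather ops."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None

sp_size = 2
x = torch.ones(4, requires_grad=True)
y = ScaleGrad.apply(x, 1 / sp_size)  # what the sequence split applies
z = ScaleGrad.apply(y, sp_size)      # what the matching gather applies
z.sum().backward()
print(x.grad)                        # tensor([1., 1., 1., 1.]): the two scales cancel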

View File

@@ -1,6 +1,7 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
+from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -30,6 +31,11 @@ except ImportError:
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer._operation import (
+    all_to_all_comm,
+    gather_forward_split_backward,
+    split_forward_gather_backward,
+)
 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import ColoAttention, dist_cross_entropy
@@ -162,6 +168,21 @@ class Qwen2PipelineForwards:
                 sliding_window=self.config.sliding_window,
             )
 
+        if shard_config and shard_config.enable_sequence_parallelism:
+            if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
+                hidden_states = split_forward_gather_backward(
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.tensor_parallel_process_group,
+                )
+            elif shard_config.sequence_parallelism_mode == "all_to_all":
+                hidden_states = split_forward_gather_backward(
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.sequence_parallel_process_group,
+                    grad_scale=1 / shard_config.sequence_parallel_size,
+                )
+
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
@@ -218,6 +239,20 @@ class Qwen2PipelineForwards:
         if stage_manager.is_last_stage():
             hidden_states = self.norm(hidden_states)
+            if shard_config and shard_config.enable_sequence_parallelism:
+                if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
+                    hidden_states = gather_forward_split_backward(
+                        hidden_states,
+                        dim=1,
+                        process_group=shard_config.tensor_parallel_process_group,
+                    )
+                elif shard_config.sequence_parallelism_mode == "all_to_all":
+                    hidden_states = gather_forward_split_backward(
+                        hidden_states,
+                        dim=1,
+                        process_group=shard_config.sequence_parallel_process_group,
+                        grad_scale=shard_config.sequence_parallel_size,
+                    )
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
@@ -453,7 +488,7 @@ class Qwen2PipelineForwards:
         return {"hidden_states": hidden_states}
 
 
-def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
+def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
     def forward(
         self: Qwen2Attention,
         hidden_states: torch.Tensor,
@@ -464,12 +499,28 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if sp_mode is not None:
+            assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
+            assert (sp_size is not None) and (
+                sp_group is not None
+            ), "Must specify sp_size and sp_group for sequence parallel"
+
         bsz, q_len, _ = hidden_states.size()
+        # sp: modify sp_len when sequence parallel mode is ring
+        if sp_mode in ["split_gather", "ring"]:
+            q_len *= sp_size
 
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
+        # sp: all-to-all communication when introducing sequence parallel
+        if sp_mode == "all_to_all":
+            query_states = all_to_all_comm(query_states, sp_group)
+            key_states = all_to_all_comm(key_states, sp_group)
+            value_states = all_to_all_comm(value_states, sp_group)
+            bsz, q_len, _ = query_states.size()
+
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -522,10 +573,41 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
-        attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
+        if shard_config.enable_flash_attention:
+            assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
+            attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
+        else:
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                    f" {attn_weights.size()}"
+                )
+            if attention_mask is not None:
+                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                    raise ValueError(
+                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                    )
+                attn_weights = attn_weights + attention_mask
+
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+            attn_output = torch.matmul(attn_weights, value_states)
+
+            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+                raise ValueError(
+                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                    f" {attn_output.size()}"
+                )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        if sp_mode == "all_to_all":
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
+            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
+        else:
+            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
         attn_output = self.o_proj(attn_output)
 
         return attn_output, None, past_key_value
@@ -533,9 +615,8 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
     return forward
 
 
-def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig):
+def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
     logger = logging.get_logger(__name__)
-    assert shard_config.enable_flash_attention, "Flash Attention is not enabled."
 
     def forward(
         self,
@@ -585,17 +666,26 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig):
         # embed positions
         hidden_states = inputs_embeds
 
-        # in this case, attention_mask is a dict rather than a tensor
-        mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
-        attention_mask = ColoAttention.prepare_attn_kwargs(
-            mask_shape,
-            hidden_states.dtype,
-            hidden_states.device,
-            q_padding_mask=attention_mask,
-            is_causal=True,
-        )
+        if shard_config.enable_flash_attention:
+            # in this case, attention_mask is a dict rather than a tensor
+            mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+            attention_mask = ColoAttention.prepare_attn_kwargs(
+                mask_shape,
+                hidden_states.dtype,
+                hidden_states.device,
+                q_padding_mask=attention_mask,
+                is_causal=True,
+            )
+        else:
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+                sliding_window=self.config.sliding_window,
+            )
 
-        if self.gradient_checkpointing and self.training:
+        if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
             if use_cache:
                 logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@@ -607,6 +697,11 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig):
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
+        if sp_mode in ["ring", "split_gather"]:
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+        elif sp_mode == "all_to_all":
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -641,6 +736,11 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig):
         hidden_states = self.norm(hidden_states)
 
+        if sp_mode == "ring" or sp_mode == "split_gather":
+            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+        elif sp_mode == "all_to_all":
+            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+
         # add hidden states from the last decoder layer
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
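The new non-flash branch in get_qwen2_flash_attention_forward is ordinary scaled dot-product attention. For reference, the same computation in isolation (illustrative shapes; the diff additionally upcasts the softmax to fp32 when running in half precision):

import math
import torch
from torch import nn

bsz, num_heads, q_len, head_dim = 2, 16, 128, 64   # illustrative shapes
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn_like(q)
v = torch.randn_like(q)
mask = torch.zeros(bsz, 1, q_len, q_len)           # additive mask: 0 keeps, -inf blocks

attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = attn_weights + mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)        # [bsz, num_heads, q_len, head_dim]
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, num_heads * head_dim)
print(attn_output.shape)                           # torch.Size([2, 128, 1024])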

View File

@@ -9,6 +9,7 @@ import colossalai.shardformer.layer as col_nn
 from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
 
 from ..modeling.chatglm2 import (
+    get_chatglm_sequence_parallel_attention_forward,
     get_chatglm_sequence_parallel_forward_fn,
     get_flash_core_attention_forward,
     get_jit_fused_glm_block_forward,
@@ -58,14 +59,29 @@ class ChatGLMPolicy(Policy):
             norm_cls = col_nn.LayerNorm
 
         sp_mode = self.shard_config.sequence_parallelism_mode or None
-        assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for ChatGLM2"
+        sp_size = self.shard_config.sequence_parallel_size or None
+        sp_group = self.shard_config.sequence_parallel_process_group or None
         if sp_mode == "ring":
             warnings.warn(
                 f"For ChatGLM2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
             )
             sp_mode = "split_gather"
         overlap = self.shard_config.enable_sequence_overlap
-        sp_partial_derived = sp_mode == "split_gather"
+        sp_partial_derived = sp_mode in ["split_gather"]
+
+        if sp_mode == "all_to_all":
+            decoder_attribute_replacement = {
+                "num_heads": self.model.config.num_attention_heads // sp_size,
+                "hidden_size_per_partition": self.model.config.kv_channels
+                * self.model.config.num_attention_heads
+                // sp_size,
+            }
+            if getattr(self.model.config, "num_key_value_heads", False):
+                decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
+
+            policy["CoreAttention"] = ModulePolicyDescription(
+                attribute_replacement=decoder_attribute_replacement,
+            )
 
         if self.shard_config.enable_tensor_parallelism:
             assert (
@@ -179,12 +195,26 @@ class ChatGLMPolicy(Policy):
             )
 
         # use sequence parallel
-        if sp_mode == "split_gather":
+        if self.shard_config.enable_sequence_parallelism:
             self.append_or_create_method_replacement(
-                description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
+                description={
+                    "forward": get_chatglm_sequence_parallel_attention_forward(
+                        self.shard_config, sp_mode, sp_size, sp_group
+                    ),
+                },
                 policy=policy,
-                target_key="ChatGLMModel",
+                target_key="SelfAttention",
             )
+            if self.pipeline_stage_manager is None:
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_chatglm_sequence_parallel_forward_fn(
+                            self.shard_config, sp_mode, sp_size, sp_group
+                        )
+                    },
+                    policy=policy,
+                    target_key="ChatGLMModel",
+                )
 
         # use jit fused operator
         if self.shard_config.enable_jit_fused:

View File

@@ -1,4 +1,3 @@
-import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union
@@ -66,13 +65,6 @@ class CommandPolicy(Policy):
         else:
             norm_cls = LayerNorm
 
-        if self.pipeline_stage_manager is not None:
-            self.shard_config.enable_sequence_parallelism = False
-            self.shard_config.enable_sequence_overlap = False
-            self.shard_config.sequence_parallelism_mode = None
-            warnings.warn(
-                f"For Command, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
-            )
         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None

View File

@@ -1,4 +1,3 @@
-import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union
@@ -82,9 +81,20 @@ class Qwen2Policy(Policy):
             embedding_cls = PaddingEmbedding
         norm_cls = FusedRMSNorm if self.shard_config.enable_fused_normalization else RMSNorm
 
-        if self.shard_config.enable_sequence_parallelism:
-            self.shard_config.enable_sequence_parallelism = False
-            warnings.warn("Qwen2 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
+        sp_size = self.shard_config.sequence_parallel_size or None
+        sp_group = self.shard_config.sequence_parallel_process_group or None
+        sp_partial_derived = sp_mode in ["split_gather", "ring"]
+
+        if sp_mode == "all_to_all":
+            decoder_attribute_replacement = {
+                "num_heads": self.model.config.num_attention_heads // sp_size,
+            }
+            if getattr(self.model.config, "num_key_value_heads", False):
+                decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
+
+            policy[attn_cls] = ModulePolicyDescription(
+                attribute_replacement=decoder_attribute_replacement,
+            )
 
         if self.shard_config.enable_tensor_parallelism:
             assert (
@@ -109,30 +119,37 @@ class Qwen2Policy(Policy):
                 SubModuleReplacementDescription(
                     suffix="self_attn.q_proj",
                     target_module=Linear1D_Col,
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.k_proj",
                     target_module=Linear1D_Col,
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.v_proj",
                     target_module=Linear1D_Col,
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.o_proj",
                     target_module=Linear1D_Row,
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.gate_proj",
                     target_module=Linear1D_Col,
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.up_proj",
                     target_module=Linear1D_Col,
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.down_proj",
                     target_module=Linear1D_Row,
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
             ],
         )
@@ -154,10 +171,12 @@ class Qwen2Policy(Policy):
                 SubModuleReplacementDescription(
                     suffix="input_layernorm",
                     target_module=norm_cls,
+                    kwargs={"sp_partial_derived": sp_partial_derived},
                 ),
                 SubModuleReplacementDescription(
                     suffix="post_attention_layernorm",
                     target_module=norm_cls,
+                    kwargs={"sp_partial_derived": sp_partial_derived},
                 ),
             ],
             policy=policy,
@@ -168,16 +187,16 @@ class Qwen2Policy(Policy):
             description=SubModuleReplacementDescription(
                 suffix="norm",
                 target_module=norm_cls,
+                kwargs={"sp_partial_derived": sp_partial_derived},
             ),
             policy=policy,
             target_key=Qwen2Model,
         )
 
-        # use flash attention
-        if self.shard_config.enable_flash_attention:
+        if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_qwen2_flash_attention_forward(self.shard_config),
+                    "forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
                 target_key=attn_cls,
@@ -186,7 +205,9 @@ class Qwen2Policy(Policy):
             # replace qwen2 model forward method
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_qwen2_model_forward_for_flash_attn(self.shard_config),
+                    "forward": get_qwen2_model_forward_for_flash_attn(
+                        self.shard_config, sp_mode, sp_size, sp_group
+                    ),
                 },
                 policy=policy,
                 target_key=Qwen2Model,

View File

@@ -136,6 +136,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
+        {  # Ulysses + Flash attention
+            "tp_size": 1,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "sp_size": 2,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {
             "tp_size": 4,
             "pp_size": 1,

View File

@@ -58,6 +58,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # Check the grad when using ZeRO-1 and ZeRO-2
     if (
         booster.plugin.zero_stage in [1, 2]
+        and booster.plugin.shard_config.pipeline_stage_manager is None
         and booster.plugin.shard_config.enable_sequence_parallelism
         and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
     ):
@@ -154,6 +155,45 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
+        {  # Ulysses + Flash attention
+            "tp_size": 1,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {
             "tp_size": 2,
             "pp_size": 1,

View File

@@ -180,6 +180,68 @@ def run_qwen2_test(test_config):
             "zero_stage": 1,
             "initial_scale": 1,
         },
+        {  # Ulysses + Flash attention
+            "tp_size": 1,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "sp_size": 2,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {
             "tp_size": 2,
             "pp_size": 2,