pull/6023/head
wangbluo 2024-08-21 03:21:49 +00:00
parent eb5ba40def
commit 193030f696
2 changed files with 14 additions and 24 deletions

File: colossalai/shardformer/layer/_operation.py

@@ -1097,11 +1097,13 @@ def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8
     return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
 
 
-def gather_sp_output(hidden_states, sp_group, sp_mode):
+def gather_sp_output(hidden_states, sp_group, sp_mode, fp8_communication=False):
     """
     Gather the output of the last layer for cross entropy computation
     """
     # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
     scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
-    hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=scale)
+    hidden_states = gather_forward_split_backward(
+        hidden_states, 1, sp_group, grad_scale=scale, fp8_communication=fp8_communication
+    )
     return hidden_states
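
Note: the grad-rescale logic above is unchanged; the new fp8_communication flag is simply threaded through to gather_forward_split_backward. As a rough sketch (assuming, as the branches elsewhere in this diff suggest, that is_share_sp_tp flags the "split_gather" and "ring" modes), the scale selection amounts to:

    from typing import Optional

    import torch.distributed as dist

    def _grad_scale(sp_mode: str, sp_group) -> Optional[int]:
        # "split_gather"/"ring" share the TP group, so ZeRO's grad averaging
        # over DP * SP already accounts for the gather; "all_to_all" uses a
        # dedicated SP group and must rescale by its world size.
        if sp_mode in ("split_gather", "ring"):  # stands in for is_share_sp_tp
            return None
        return dist.get_world_size(sp_group)  # requires an initialized process group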

File: colossalai/shardformer/modeling/llama.py

@@ -26,11 +26,7 @@ from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import AttnMaskType
-from colossalai.shardformer.layer._operation import (
-    all_to_all_comm,
-    gather_forward_split_backward,
-    split_forward_gather_backward,
-)
+from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
 from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
 from colossalai.shardformer.shard import ShardConfig
@@ -235,12 +231,8 @@ class LlamaPipelineForwards:
         if stage_manager.is_last_stage():
             hidden_states = self.norm(hidden_states)
             if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
-                hidden_states = gather_forward_split_backward(
-                    hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
-                )
-            elif sp_mode == "all_to_all":
-                hidden_states = gather_forward_split_backward(
-                    hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
-                )
+                hidden_states = gather_sp_output(
+                    hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
+                )
 
         # add hidden states from the last decoder layer
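
The single gather_sp_output call above replaces both branches because the helper now owns the grad_scale choice. For intuition, here is a single-process sketch of the underlying gather_forward_split_backward primitive (a stand-in only: the real version all-gathers across ranks, while here every "rank" holds the same shard):

    import torch

    class _GatherFwdSplitBwd(torch.autograd.Function):
        # Forward: concatenate the sequence shards along `dim`.
        # Backward: keep only this rank's slice, scaled by `grad_scale`.
        @staticmethod
        def forward(ctx, shard, dim, world_size, rank, grad_scale):
            ctx.dim, ctx.world_size, ctx.rank, ctx.grad_scale = dim, world_size, rank, grad_scale
            return torch.cat([shard] * world_size, dim=dim)  # stand-in for dist.all_gather

        @staticmethod
        def backward(ctx, grad_output):
            grad = grad_output.chunk(ctx.world_size, dim=ctx.dim)[ctx.rank]
            if ctx.grad_scale is not None:
                grad = grad * ctx.grad_scale
            return grad, None, None, None, None

    # e.g. gathered = _GatherFwdSplitBwd.apply(shard, 1, 4, 0, 4.0)
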
@@ -546,7 +538,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
             key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
             value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
@@ -565,6 +556,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             )
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
+        cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
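
The re-added rotary call matches the newer Hugging Face API, where the embedding module returns cos/sin tables indexed by position_ids. For reference, apply_rotary_pos_emb in transformers looks essentially like this (reference sketch, not code from this PR):

    import torch

    def rotate_half(x):
        # Rotate the two halves of the head dim: (x1, x2) -> (-x2, x1).
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
        # cos/sin: (batch, seq_len, head_dim); unsqueeze broadcasts over heads.
        cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
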
@ -683,7 +675,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
past_seen_tokens = 0 past_seen_tokens = 0
seq_len = inputs_embeds.shape[1] seq_len = inputs_embeds.shape[1]
inputs_embeds.shape[0] batch_size = inputs_embeds.shape[0]
if use_cache: # kept for BC (cache positions) if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache): if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values = DynamicCache.from_legacy_cache(past_key_values)
@@ -697,7 +689,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
             position_ids = cache_position.unsqueeze(0)
 
         if shard_config.enable_flash_attention:
-            mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
+            mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len)
             attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
                 mask_shape,
                 inputs_embeds.dtype,
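
The corrected mask shape is easiest to check with made-up numbers: with a KV cache, only the seq_len newly fed tokens issue queries, but each query may attend to every cached and current key:

    batch_size, past_seen_tokens, seq_len = 2, 16, 4

    old_shape = (batch_size, 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
    new_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len)

    assert old_shape == (2, 1, 20, 20)  # rows for queries that don't exist this step
    assert new_shape == (2, 1, 4, 20)   # (bsz, broadcast over heads, q_len, kv_len)
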
@@ -771,14 +763,10 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
                 all_self_attns += (layer_outputs[1],)
 
         hidden_states = self.norm(hidden_states)
 
-        if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(
-                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
-            )
-        elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(
-                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
-            )
+        # Cases that don't support parallelizing cross entropy computation along sequence
+        if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
+            hidden_states = gather_sp_output(
+                hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
+            )
 
         # add hidden states from the last decoder layer