pull/6016/head
wangbluo 2024-08-21 03:21:49 +00:00
parent eb5ba40def
commit 193030f696
2 changed files with 14 additions and 24 deletions

View File

@ -1097,11 +1097,13 @@ def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
def gather_sp_output(hidden_states, sp_group, sp_mode):
def gather_sp_output(hidden_states, sp_group, sp_mode, fp8_communication=False):
"""
Gather the output of the last layer for cross entropy computation
"""
# Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=scale)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=scale, fp8_communication=fp8_communication
)
return hidden_states

View File

@ -26,11 +26,7 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import AttnMaskType
from colossalai.shardformer.layer._operation import (
all_to_all_comm,
gather_forward_split_backward,
split_forward_gather_backward,
)
from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
from colossalai.shardformer.shard import ShardConfig
@ -235,12 +231,8 @@ class LlamaPipelineForwards:
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
hidden_states = gather_sp_output(
hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer
@ -546,7 +538,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
# sp: all-to-all comminucation when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
@ -565,6 +556,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@ -683,7 +675,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
past_seen_tokens = 0
seq_len = inputs_embeds.shape[1]
inputs_embeds.shape[0]
batch_size = inputs_embeds.shape[0]
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
@ -697,7 +689,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
position_ids = cache_position.unsqueeze(0)
if shard_config.enable_flash_attention:
mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len)
attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
mask_shape,
inputs_embeds.dtype,
@ -771,14 +763,10 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
# Cases that don't support parallelizing cross entropy computation along sequence
if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
hidden_states = gather_sp_output(
hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer