mirror of https://github.com/hpcaitech/ColossalAI

commit 193030f696 ("fix")
parent eb5ba40def

@@ -1097,11 +1097,13 @@ def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8
     return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)


-def gather_sp_output(hidden_states, sp_group, sp_mode):
+def gather_sp_output(hidden_states, sp_group, sp_mode, fp8_communication=False):
     """
     Gather the output of the last layer for cross entropy computation
     """
     # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
     scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
-    hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=scale)
+    hidden_states = gather_forward_split_backward(
+        hidden_states, 1, sp_group, grad_scale=scale, fp8_communication=fp8_communication
+    )
     return hidden_states
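
Note: the new keyword only threads the FP8 flag through to gather_forward_split_backward. A minimal sketch of a call site, mirroring the usage that appears later in this diff (sp_group, sp_mode, and shard_config are assumed to come from the plugin / shard configuration):

    # Hypothetical call site; hidden_states is sharded along the sequence dim (dim 1).
    hidden_states = gather_sp_output(
        hidden_states,                                      # (batch, seq_len // sp_size, hidden)
        sp_group,                                           # sequence-parallel process group
        sp_mode,                                            # "all_to_all", "split_gather", or "ring"
        fp8_communication=shard_config.fp8_communication,   # flag added by this commit
    )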

@@ -26,11 +26,7 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import AttnMaskType
-from colossalai.shardformer.layer._operation import (
-    all_to_all_comm,
-    gather_forward_split_backward,
-    split_forward_gather_backward,
-)
+from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
 from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
 from colossalai.shardformer.shard import ShardConfig

@@ -235,12 +231,8 @@ class LlamaPipelineForwards:
         if stage_manager.is_last_stage():
             hidden_states = self.norm(hidden_states)
             if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
-                hidden_states = gather_forward_split_backward(
-                    hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
-                )
-            elif sp_mode == "all_to_all":
-                hidden_states = gather_forward_split_backward(
-                    hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+                hidden_states = gather_sp_output(
+                    hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
                 )

         # add hidden states from the last decoder layer
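
As background for the gather condition above: an illustrative predicate for when the last hidden states are gathered along the sequence dimension before the loss. The helper name and the assumption that is_share_sp_tp(sp_mode) is true for the tensor-parallel-sharing modes "split_gather" and "ring" are not taken from the diff itself:

    # Illustrative only, not the library implementation.
    def needs_full_sequence_output(parallel_output: bool, force_gather: bool, sp_mode: str) -> bool:
        share_sp_tp = sp_mode in ("split_gather", "ring")  # assumed behaviour of is_share_sp_tp
        # Gather whenever sequence-parallel cross entropy is unsupported or explicitly disabled.
        return (not parallel_output) or force_gather or share_sp_tp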

@@ -546,7 +538,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
             key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
             value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
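
As context for the all_to_all_comm calls above: per the signature in the first hunk (scatter_dim=2, gather_dim=1 by default), the exchange turns a sequence-sharded projection into a head/hidden-sharded one, which is why bsz and q_len are re-read from query_states afterwards. A minimal forward-only sketch in plain torch.distributed (no autograd, no FP8, equal chunk sizes assumed; illustrative, not ColossalAI's implementation):

    import torch
    import torch.distributed as dist

    def naive_seq_all_to_all(x: torch.Tensor, group, scatter_dim: int = 2, gather_dim: int = 1) -> torch.Tensor:
        # Split x along scatter_dim, exchange one chunk with every rank in the group,
        # then concatenate the received chunks along gather_dim.
        world_size = dist.get_world_size(group)
        send_chunks = [c.contiguous() for c in x.chunk(world_size, dim=scatter_dim)]
        recv_chunks = [torch.empty_like(c) for c in send_chunks]
        dist.all_to_all(recv_chunks, send_chunks, group=group)
        return torch.cat(recv_chunks, dim=gather_dim)

    # e.g. a (bsz, q_len // sp_size, hidden) shard comes back as (bsz, q_len, hidden // sp_size)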

@@ -565,6 +556,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

@@ -683,7 +675,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
         past_seen_tokens = 0
         seq_len = inputs_embeds.shape[1]
-        inputs_embeds.shape[0]
+        batch_size = inputs_embeds.shape[0]
         if use_cache:  # kept for BC (cache positions)
             if not isinstance(past_key_values, StaticCache):
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)

@@ -697,7 +689,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
             position_ids = cache_position.unsqueeze(0)

         if shard_config.enable_flash_attention:
-            mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
+            mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len)
             attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
                 mask_shape,
                 inputs_embeds.dtype,
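
The reworked mask_shape keeps query rows only for the tokens processed in this forward pass, while the key/value columns still span the cache plus the new tokens. A small illustration with assumed numbers (decoding one token against a 128-token KV cache):

    # Illustrative shapes only.
    past_seen_tokens, seq_len, batch_size = 128, 1, 4
    old_shape = (batch_size, 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)  # (4, 1, 129, 129)
    new_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len)                     # (4, 1, 1, 129)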

@@ -771,14 +763,10 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
                 all_self_attns += (layer_outputs[1],)

         hidden_states = self.norm(hidden_states)

-        if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(
-                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
-            )
-        elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(
-                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+        # Cases that don't support parallelizing cross entropy computation along sequence
+        if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
+            hidden_states = gather_sp_output(
+                hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
             )

         # add hidden states from the last decoder layer