@@ -24,6 +24,7 @@ from colossalai.moe._operation import (
     all_to_all_uneven,
 )
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.quantization.fp8 import all_reduce_fp8
 from colossalai.shardformer.layer._operation import (
     all_to_all_comm,
     gather_forward_split_backward,
@@ -61,7 +62,14 @@ class EPDeepseekMoE(nn.Module):
     def __init__(self):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

-    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
+    def setup_process_groups(
+        self,
+        tp_group: ProcessGroup,
+        moe_dp_group: ProcessGroup,
+        ep_group: ProcessGroup,
+        fp8_communication: bool = False,
+    ):
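+        # fp8_communication: use FP8-compressed payloads for this module's collectives (default off)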
         assert tp_group is not None
         assert moe_dp_group is not None
         assert ep_group is not None
@@ -70,6 +77,7 @@ class EPDeepseekMoE(nn.Module):
         self.ep_rank = dist.get_rank(ep_group)
         self.num_experts = self.config.n_routed_experts
         assert self.num_experts % self.ep_size == 0
+        self.fp8_communication = fp8_communication

         self.ep_group = ep_group
         self.num_experts_per_ep = self.num_experts // self.ep_size
@@ -86,9 +94,16 @@ class EPDeepseekMoE(nn.Module):
         self.tp_group = tp_group
         if self.tp_group.size() > 1:
             for expert in held_experts:
-                expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group)
-                expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group)
-                expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group)
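+                # TP-shard each expert's MLP; the flag routes the parallel linears' collectives through FP8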
+                expert.gate_proj = Linear1D_Col.from_native_module(
+                    expert.gate_proj, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.up_proj = Linear1D_Col.from_native_module(
+                    expert.up_proj, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.down_proj = Linear1D_Row.from_native_module(
+                    expert.down_proj, self.tp_group, fp8_communication=self.fp8_communication
+                )

         for p in self.experts.parameters():
             set_moe_tensor_ep_group(p, ep_group)
@@ -106,7 +120,9 @@ class EPDeepseekMoE(nn.Module):
         if module.__class__.__name__ == "DeepseekMLP":
             return module
         module.__class__ = EPDeepseekMoE
-        module.setup_process_groups(tp_group, moe_dp_group, ep_group)
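+        # read the optional flag from kwargs so call sites that omit it keep the non-FP8 default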
+        fp8_communication = kwargs.get("fp8_communication", False)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication=fp8_communication)
         return module

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -137,11 +152,23 @@ class EPDeepseekMoE(nn.Module):
             for i in range(1, self.ep_size):
                 activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
             activate_experts = (activate_experts > 0).float()
-            dist.all_reduce(activate_experts, group=self.moe_dp_group)
+
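+            # sync which experts were activated across the MoE DP group, in FP8 when enabled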
+            if self.fp8_communication:
+                all_reduce_fp8(activate_experts, group=self.moe_dp_group)
+            else:
+                dist.all_reduce(activate_experts, group=self.moe_dp_group)

         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
-        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
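+        # dispatch tokens to their expert-owning ranks; the uneven all-to-all sends FP8 payloads when enabled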
+        output_states, _ = all_to_all_uneven(
+            dispatch_states,
+            input_split_list,
+            output_split_list,
+            self.ep_group,
+            fp8_communication=self.fp8_communication,
+        )
         output_states = EPGradScalerIn.apply(output_states, self.ep_size)

         if output_states.size(0) > 0:
@@ -167,7 +192,10 @@ class EPDeepseekMoE(nn.Module):
             output_states_list.append(split_states)
         output_states = torch.cat(output_states_list)
         output_states = EPGradScalerOut.apply(output_states, self.ep_size)
-        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
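+        # reverse all-to-all: route expert outputs back to the ranks that own the tokens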
+        dispatch_states, _ = all_to_all_uneven(
+            output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
+        )
         recover_token_idx = torch.empty_like(flat_topk_token_idx)
         recover_token_idx[flat_topk_token_idx] = torch.arange(
             flat_topk_token_idx.size(0), device=flat_topk_token_idx.device
@@ -534,9 +561,10 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
         # sp: all-to-all communication when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group)
-            key_states = all_to_all_comm(key_states, sp_group)
-            value_states = all_to_all_comm(value_states, sp_group)
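+            # swap sequence sharding for head sharding across sp_group; FP8 payloads when enabled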
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()

         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
@@ -595,7 +622,10 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
         # sp: all-to-all communication when introducing sequence parallel
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (1, 8, 128)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)  # (1, 4, 256)
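+            # reverse exchange: scatter along sequence, gather along hidden; FP8 when enabled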
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )  # (1, 4, 256)
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

@@ -685,9 +714,14 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
             )

         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
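+            # split the embeddings along the sequence dim for SP; FP8 comm when enabled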
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+            )

         # embed positions
         hidden_states = inputs_embeds
@@ -731,9 +764,14 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
         hidden_states = self.norm(hidden_states)

         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
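+            # gather sequence shards back into full hidden states; FP8 comm when enabled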
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+            )

         # add hidden states from the last decoder layer
         if output_hidden_states: