[fp8]Moe support fp8 communication (#5977)

* fix

* support moe fp8

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

fix

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Branch: pull/5998/head
Author: flybird11111, 2024-08-09 18:26:02 +08:00 (committed by GitHub)
Parent: e4aadeee20
Commit: f1a3a326c4
8 changed files with 160 additions and 52 deletions

View File

@@ -214,6 +214,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         moe_dp_outside: bool = True,
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
+        fp8_communication: bool = False,
     ) -> None:
         if overlap_communication or zero_stage == 2:
             overlap_communication = False
@@ -341,6 +342,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             parallel_output=parallel_output,
             make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             gradient_checkpoint_config=gradient_checkpoint_config,
+            fp8_communication=fp8_communication,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
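
For orientation, here is a hedged usage sketch of the new plugin flag. Only fp8_communication=True is taken from the hunk above; the other keyword arguments, the tiny Mixtral config, and the launcher call are assumptions that may need adjusting for a given ColossalAI version.

# Hypothetical sketch: turn on FP8 communication for MoE training.
# Launch with: torchrun --nproc_per_node=2 train_moe_fp8.py
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from transformers import MixtralConfig, MixtralForCausalLM

colossalai.launch_from_torch()  # rank/world size come from the torchrun env

plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,                # shard experts across 2 ranks (assumed setup)
    zero_stage=1,
    precision="bf16",
    fp8_communication=True,   # the flag added by this commit
)
booster = Booster(plugin=plugin)

# A deliberately tiny MoE model so the sketch stays cheap to run.
config = MixtralConfig(num_hidden_layers=2, hidden_size=128, intermediate_size=256, num_local_experts=4)
model = MixtralForCausalLM(config)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)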

View File

@@ -6,6 +6,8 @@ from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup
 
+from colossalai.quantization.fp8 import all_to_all_single_fp8
+
 MOE_KERNEL = None
@@ -380,6 +382,7 @@ def _all_to_all(
     output_split_sizes: Optional[List[int]] = None,
     group=None,
     async_op: bool = False,
+    fp8_communication: bool = False,
 ):
     """
     Returns:
@@ -392,9 +395,14 @@ def _all_to_all(
     outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
     inputs = inputs.contiguous()
     outputs = outputs.contiguous()
-    handle = dist.all_to_all_single(
-        outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
-    )
+    if fp8_communication:
+        handle = all_to_all_single_fp8(
+            outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=False
+        )
+    else:
+        handle = dist.all_to_all_single(
+            outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
+        )
     return outputs, handle
@@ -407,6 +415,7 @@ class AllToAllUneven(torch.autograd.Function):
         output_split_sizes=None,
         group=None,
         overlap: bool = False,
+        fp8_communication: bool = False,
     ):
         """
         Returns:
@@ -416,7 +425,9 @@ class AllToAllUneven(torch.autograd.Function):
         ctx.input_split_sizes = input_split_sizes
         ctx.output_split_sizes = output_split_sizes
         ctx.group = group
-        return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
+        return _all_to_all(
+            inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication=fp8_communication
+        )
 
     @staticmethod
     def backward(ctx: Any, *grad_outputs):
@@ -426,6 +437,7 @@ class AllToAllUneven(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
@@ -435,8 +447,9 @@ def all_to_all_uneven(
     output_split_sizes: Optional[List[int]] = None,
     group=None,
     overlap: bool = False,
+    fp8_communication: bool = False,
 ):
     assert (
         inputs.requires_grad
     ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
-    return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
+    return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)
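
The change above threads a fp8_communication flag from all_to_all_uneven down through AllToAllUneven.apply into _all_to_all, which then calls all_to_all_single_fp8 instead of dist.all_to_all_single. Note that the FP8 path passes async_op=False, so the second return value is not a usable async handle in that case. A minimal call-site sketch, assuming two ranks launched with torchrun and made-up split sizes:

# Hypothetical sketch of the extended all_to_all_uneven API (2 ranks).
# Launch with: torchrun --nproc_per_node=2 demo_a2a_uneven.py
import torch
import torch.distributed as dist
from colossalai.moe._operation import all_to_all_uneven

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Rank 0 keeps 2 rows and sends 3 to rank 1; rank 1 sends 1 row and keeps 4.
input_split_sizes = [2, 3] if rank == 0 else [1, 4]
output_split_sizes = [2, 1] if rank == 0 else [3, 4]

inputs = torch.randn(sum(input_split_sizes), 16, device="cuda", requires_grad=True)

outputs, _ = all_to_all_uneven(
    inputs,
    input_split_sizes,
    output_split_sizes,
    group=dist.group.WORLD,
    fp8_communication=True,  # quantize the exchanged payload to FP8
)
outputs.sum().backward()  # inputs must require grad, as asserted above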

View File

@@ -27,16 +27,19 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
     fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
     fp8_max = torch.finfo(fp8_type).max
-    if per_channel_scale:
-        per_channel_max = inp.abs().max(dim=-1).values.float()
-        per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
-        scale = fp8_max / per_channel_max[:, None]
-        scale_inv = per_channel_max / fp8_max
+    if inp.numel() == 0:
+        return inp.to(fp8_type), torch.tensor([1.0], device=inp.device)
     else:
-        per_tensor_max = inp.abs().max().float()
-        per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
-        scale = fp8_max / per_tensor_max
-        scale_inv = 1.0 / scale
+        if per_channel_scale:
+            per_channel_max = inp.abs().max(dim=-1).values.float()
+            per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
+            scale = fp8_max / per_channel_max[:, None]
+            scale_inv = per_channel_max / fp8_max
+        else:
+            per_tensor_max = inp.abs().max().float()
+            per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
+            scale = fp8_max / per_tensor_max
+            scale_inv = 1.0 / scale
     ret = (scale * inp.float()).to(fp8_type)
     return ret, scale_inv
@@ -113,7 +116,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
     tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
     dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
     for i in range(world_size):
-        tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
+        tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i].to(input_device)
     out = torch.cat(tensor_list, dim=0)
     tensor.copy_(out[:input_size].view(input_shape).to(input_type))
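
Two things change in the quantization helpers: cast_to_fp8 now short-circuits on empty tensors (returning an empty FP8 tensor with a unit scale_inv, which can occur in MoE when a rank receives no tokens for an expert), and all_reduce_fp8 moves each gathered scale onto the input's device before dequantizing. For readers unfamiliar with the scaling scheme, the per-tensor branch boils down to scale = fp8_max / amax on the way in and a multiply by scale_inv = 1 / scale on the way out. Below is a standalone illustration using plain PyTorch float8 dtypes (torch >= 2.1); it mirrors the arithmetic above and is not the library function itself.

# Standalone illustration of per-tensor FP8 scaling (mirrors cast_to_fp8 above).
import torch

def quantize_per_tensor(x: torch.Tensor, fp8_format: str = "e4m3"):
    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_type).max          # 448.0 for e4m3, 57344.0 for e5m2
    amax = x.abs().max().float()
    amax = torch.where(amax > 0, amax, torch.ones_like(amax))  # avoid dividing by zero
    scale = fp8_max / amax                        # stretch values to fill the FP8 range
    scale_inv = 1.0 / scale                       # kept alongside the payload to undo it
    return (scale * x.float()).to(fp8_type), scale_inv

x = torch.randn(4, 8)
x_fp8, scale_inv = quantize_per_tensor(x)
x_back = x_fp8.float() * scale_inv                # dequantize
print((x - x_back).abs().max())                   # small rounding error from FP8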

View File

@@ -274,6 +274,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
         weight: Optional[nn.Parameter] = None,
         weight_initializer: Callable = init.normal_(),
         make_vocab_size_divisible_by: int = 64,
+        fp8_communication: bool = False,
         *args,
         **kwargs,
     ):
@@ -282,6 +283,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
         self.embed_args = args
         self.embed_kwargs = kwargs
         self.process_group = process_group
+        self.fp8_communication = fp8_communication
 
         tensor_parallel_size = dist.get_world_size(group=process_group)
         tensor_parallel_rank = dist.get_rank(group=process_group)
@@ -390,5 +392,5 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
         embedding_output = output_parallel.clone()
         embedding_output[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
-        output = reduce_forward(embedding_output, self.process_group)
+        output = reduce_forward(embedding_output, self.process_group, fp8_communication=self.fp8_communication)
         return output

View File

@@ -24,6 +24,7 @@ from colossalai.moe._operation import (
     all_to_all_uneven,
 )
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.quantization.fp8 import all_reduce_fp8
 from colossalai.shardformer.layer._operation import (
     all_to_all_comm,
     gather_forward_split_backward,
@@ -61,7 +62,13 @@ class EPDeepseekMoE(nn.Module):
     def __init__(self):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
 
-    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
+    def setup_process_groups(
+        self,
+        tp_group: ProcessGroup,
+        moe_dp_group: ProcessGroup,
+        ep_group: ProcessGroup,
+        fp8_communication: bool = False,
+    ):
         assert tp_group is not None
         assert moe_dp_group is not None
         assert ep_group is not None
@@ -70,6 +77,7 @@ class EPDeepseekMoE(nn.Module):
         self.ep_rank = dist.get_rank(ep_group)
         self.num_experts = self.config.n_routed_experts
         assert self.num_experts % self.ep_size == 0
+        self.fp8_communication = fp8_communication
 
         self.ep_group = ep_group
         self.num_experts_per_ep = self.num_experts // self.ep_size
@@ -86,9 +94,15 @@ class EPDeepseekMoE(nn.Module):
         self.tp_group = tp_group
         if self.tp_group.size() > 1:
             for expert in held_experts:
-                expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group)
-                expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group)
-                expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group)
+                expert.gate_proj = Linear1D_Col.from_native_module(
+                    expert.gate_proj, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.up_proj = Linear1D_Col.from_native_module(
+                    expert.up_proj, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.down_proj = Linear1D_Row.from_native_module(
+                    expert.down_proj, self.tp_group, fp8_communication=self.fp8_communication
+                )
 
         for p in self.experts.parameters():
             set_moe_tensor_ep_group(p, ep_group)
@@ -106,7 +120,8 @@ class EPDeepseekMoE(nn.Module):
         if module.__class__.__name__ == "DeepseekMLP":
             return module
         module.__class__ = EPDeepseekMoE
-        module.setup_process_groups(tp_group, moe_dp_group, ep_group)
+        fp8_communication = kwargs.get("fp8_communication", False)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication=fp8_communication)
         return module
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -137,11 +152,21 @@ class EPDeepseekMoE(nn.Module):
         for i in range(1, self.ep_size):
             activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
         activate_experts = (activate_experts > 0).float()
+
-        dist.all_reduce(activate_experts, group=self.moe_dp_group)
+        if self.fp8_communication:
+            all_reduce_fp8(activate_experts, group=self.moe_dp_group)
+        else:
+            dist.all_reduce(activate_experts, group=self.moe_dp_group)
 
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
-        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
+
+        output_states, _ = all_to_all_uneven(
+            dispatch_states,
+            input_split_list,
+            output_split_list,
+            self.ep_group,
+            fp8_communication=self.fp8_communication,
+        )
         output_states = EPGradScalerIn.apply(output_states, self.ep_size)
 
         if output_states.size(0) > 0:
@@ -167,7 +192,9 @@ class EPDeepseekMoE(nn.Module):
             output_states_list.append(split_states)
         output_states = torch.cat(output_states_list)
         output_states = EPGradScalerOut.apply(output_states, self.ep_size)
-        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
+        dispatch_states, _ = all_to_all_uneven(
+            output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
+        )
 
         recover_token_idx = torch.empty_like(flat_topk_token_idx)
         recover_token_idx[flat_topk_token_idx] = torch.arange(
             flat_topk_token_idx.size(0), device=flat_topk_token_idx.device
@@ -534,9 +561,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group)
-            key_states = all_to_all_comm(key_states, sp_group)
-            value_states = all_to_all_comm(value_states, sp_group)
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
 
         # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
@@ -595,7 +622,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
        # sp: all-to-all comminucation when introducing sequence parallel
        if sp_mode == "all_to_all":
            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (1, 8, 128)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)  # (1, 4, 256)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )  # (1, 4, 256)
        else:
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
@@ -685,9 +714,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
            )
 
        if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
        elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+            )
 
        # embed positions
        hidden_states = inputs_embeds
@@ -731,9 +764,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
        hidden_states = self.norm(hidden_states)
 
        if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
        elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+            )
 
        # add hidden states from the last decoder layer
        if output_hidden_states:
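
In the DeepSeek MoE forward above, the per-expert activation statistics are now all-reduced in FP8 over the MoE data-parallel group when the flag is set, and the dispatched token payloads go through the FP8 uneven all-to-all. A hedged sketch of just the all-reduce swap follows (two ranks, invented data); only the all_reduce_fp8(tensor, group=...) call shape is taken from the diff.

# Hypothetical sketch comparing dist.all_reduce with the FP8 variant used above.
# Launch with: torchrun --nproc_per_node=2 demo_fp8_allreduce.py
import torch
import torch.distributed as dist
from colossalai.quantization.fp8 import all_reduce_fp8

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Stand-in for the "activated experts" statistics in the MoE forward.
stats = torch.rand(8, device="cuda") + rank

exact = stats.clone()
dist.all_reduce(exact)                           # full-precision reference

approx = stats.clone()
all_reduce_fp8(approx, group=dist.group.WORLD)   # payload cast to FP8 on the wire (in-place)

if rank == 0:
    print((exact - approx).abs().max())          # small FP8 rounding error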

View File

@@ -53,7 +53,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
     def __init__(self, *args, **kwargs):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
 
-    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
+    def setup_process_groups(
+        self,
+        tp_group: ProcessGroup,
+        moe_dp_group: ProcessGroup,
+        ep_group: ProcessGroup,
+        fp8_communication: bool = False,
+    ):
         assert tp_group is not None
         assert moe_dp_group is not None
         assert ep_group is not None
@@ -62,6 +68,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         self.ep_size = dist.get_world_size(ep_group)
         self.ep_rank = dist.get_rank(ep_group)
         self.ep_group = ep_group
+        self.fp8_communication = fp8_communication
 
         if self.num_experts % self.ep_size != 0:
             raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
@@ -80,9 +87,15 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         self.tp_group = tp_group
         if self.tp_group.size() > 1:
             for expert in held_experts:
-                expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group)
-                expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group)
-                expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group)
+                expert.w1 = Linear1D_Col.from_native_module(
+                    expert.w1, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.w3 = Linear1D_Col.from_native_module(
+                    expert.w3, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.w2 = Linear1D_Row.from_native_module(
+                    expert.w2, self.tp_group, fp8_communication=self.fp8_communication
+                )
 
         for p in self.experts.parameters():
             set_moe_tensor_ep_group(p, ep_group)
@@ -99,7 +112,8 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         # TODO: better init
         LazyInitContext.materialize(module)
         module.__class__ = EPMixtralSparseMoeBlock
-        module.setup_process_groups(tp_group, moe_dp_group, ep_group)
+        fp8_communication = kwargs.get("fp8_communication", False)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication)
         return module
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -120,6 +134,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
         output_split_sizes = torch.zeros_like(input_split_sizes)
         dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
+
         with torch.no_grad():
@@ -132,7 +147,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
-        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
+        output_states, _ = all_to_all_uneven(
+            dispatch_states,
+            input_split_list,
+            output_split_list,
+            self.ep_group,
+            fp8_communication=self.fp8_communication,
+        )
         # compute expert output
         output_states = EPGradScalerIn.apply(output_states, self.ep_size)
 
         if output_states.size(0) > 0:
@@ -162,7 +183,9 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         output_states = torch.cat(output_states_list)
         output_states = EPGradScalerOut.apply(output_states, self.ep_size)
-        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
+        dispatch_states, _ = all_to_all_uneven(
+            output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
+        )
 
         recover_experts_idx = torch.empty_like(selected_experts_idx)
         recover_experts_idx[selected_experts_idx] = torch.arange(
@@ -566,9 +589,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group)
-            key_states = all_to_all_comm(key_states, sp_group)
-            value_states = all_to_all_comm(value_states, sp_group)
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
 
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -780,9 +803,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
             )
 
         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+            )
         hidden_states = inputs_embeds
 
         # decoder layers
@@ -831,9 +858,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
         hidden_states = self.norm(hidden_states)
 
         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+            )
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
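
The Mixtral block follows the same dispatch recipe as the DeepSeek one: route tokens, bincount them per expert, exchange the per-expert counts with a regular dist.all_to_all_single, collapse them into per-rank split lists, and only then move the actual hidden states with all_to_all_uneven (FP8 when enabled). A hedged sketch of just that split-size bookkeeping, with random routing and illustrative sizes:

# Hypothetical sketch of the split-size bookkeeping used by the MoE forward above.
# Launch with: torchrun --nproc_per_node=2 demo_moe_splits.py  (4 experts, 2 ranks)
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

num_experts = 4
experts_per_rank = num_experts // world

# Each local token is routed to one global expert id (random here).
selected_experts = torch.randint(0, num_experts, (16,), device="cuda")
input_split_sizes = selected_experts.bincount(minlength=num_experts)

# Tell every rank how many tokens it will receive for the experts it hosts.
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes)

# Collapse per-expert counts into per-rank counts; these lists are what the
# forward above hands to all_to_all_uneven(..., fp8_communication=True).
input_split_list = input_split_sizes.view(world, experts_per_rank).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(world, experts_per_rank).sum(dim=-1).tolist()
print(rank, input_split_list, output_split_list)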

View File

@@ -118,18 +118,22 @@ class DeepseekPolicy(Policy):
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=Linear1D_Col,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=Linear1D_Col,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=Linear1D_Col,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.o_proj",
                        target_module=Linear1D_Row,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                    ),
                ],
            )
@@ -138,7 +142,10 @@ class DeepseekPolicy(Policy):
                description=SubModuleReplacementDescription(
                    suffix="embed_tokens",
                    target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs={
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                ),
                policy=policy,
                target_key="DeepseekModel",
@@ -155,6 +162,7 @@ class DeepseekPolicy(Policy):
                            "ep_group": self.shard_config.ep_group,
                            "tp_group": self.shard_config.tensor_parallel_process_group,
                            "moe_dp_group": self.shard_config.moe_dp_group,
+                            "fp8_communication": self.shard_config.fp8_communication,
                        },
                    )
                ],
@@ -305,7 +313,7 @@ class DeepseekForCausalLMPolicy(DeepseekPolicy):
                    SubModuleReplacementDescription(
                        suffix="lm_head",
                        target_module=Linear1D_Col,
-                        kwargs=dict(gather_output=True),
+                        kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
                    )
                ]
            )

View File

@@ -114,21 +114,27 @@ class MixtralPolicy(Policy):
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=Linear1D_Col,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=Linear1D_Col,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=Linear1D_Col,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.o_proj",
                        target_module=Linear1D_Row,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                    ),
                    SubModuleReplacementDescription(  # or replicate?
-                        suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True}
+                        suffix="block_sparse_moe.gate",
+                        target_module=Linear1D_Col,
+                        kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication},
                    ),
                ],
            )
@@ -138,7 +144,10 @@ class MixtralPolicy(Policy):
                description=SubModuleReplacementDescription(
                    suffix="embed_tokens",
                    target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs={
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                ),
                policy=policy,
                target_key=MixtralModel,
@@ -155,6 +164,7 @@ class MixtralPolicy(Policy):
                            "ep_group": self.shard_config.ep_group,
                            "tp_group": self.shard_config.tensor_parallel_process_group,
                            "moe_dp_group": self.shard_config.moe_dp_group,
+                            "fp8_communication": self.shard_config.fp8_communication,
                        },
                    )
                ],
@@ -282,7 +292,7 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
                    SubModuleReplacementDescription(
                        suffix="lm_head",
                        target_module=Linear1D_Col,
-                        kwargs=dict(gather_output=True),
+                        kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
                    )
                ]
            )
@@ -336,7 +346,9 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
            MixtralForSequenceClassification: ModulePolicyDescription(
                sub_module_replacement=[
                    SubModuleReplacementDescription(
-                        suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                        suffix="score",
+                        target_module=Linear1D_Col,
+                        kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
                    )
                ]
            )
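
Taken together, the two policies push the flag along two routes: tensor-parallel replacements receive it as a kwargs entry on their SubModuleReplacementDescription, and the EP MoE blocks pick it up from the kwargs passed to from_native_module. For someone writing a custom policy against this version, the pattern looks roughly like the sketch below; the function name and module suffixes are illustrative, and only the descriptor fields and the fp8_communication kwarg mirror the diffs above.

# Hypothetical sketch of forwarding shard_config.fp8_communication in a custom policy.
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.policies.base_policy import (
    ModulePolicyDescription,
    SubModuleReplacementDescription,
)

def build_attention_replacements(shard_config):
    # Every tensor-parallel replacement carries the flag in its kwargs so the
    # replaced Linear1D_Col / Linear1D_Row layers can run their collectives in FP8.
    fp8 = {"fp8_communication": shard_config.fp8_communication}
    return ModulePolicyDescription(
        sub_module_replacement=[
            SubModuleReplacementDescription(suffix="self_attn.q_proj", target_module=Linear1D_Col, kwargs=dict(fp8)),
            SubModuleReplacementDescription(suffix="self_attn.k_proj", target_module=Linear1D_Col, kwargs=dict(fp8)),
            SubModuleReplacementDescription(suffix="self_attn.v_proj", target_module=Linear1D_Col, kwargs=dict(fp8)),
            SubModuleReplacementDescription(suffix="self_attn.o_proj", target_module=Linear1D_Row, kwargs=dict(fp8)),
        ]
    )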