From f1a3a326c47e48819be35f2d760192cea3e447c5 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Fri, 9 Aug 2024 18:26:02 +0800
Subject: [PATCH] [fp8] MoE supports fp8 communication (#5977)

* fix

* support moe fp8

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* fix

* fix

* fix fix fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../plugin/moe_hybrid_parallel_plugin.py    |  2 +
 colossalai/moe/_operation.py                | 23 +++++--
 colossalai/quantization/fp8.py              | 23 ++++---
 colossalai/shardformer/layer/embedding.py   |  4 +-
 colossalai/shardformer/modeling/deepseek.py | 69 ++++++++++++++-----
 colossalai/shardformer/modeling/mixtral.py  | 59 ++++++++++++----
 colossalai/shardformer/policies/deepseek.py | 12 +++-
 colossalai/shardformer/policies/mixtral.py  | 20 ++++--
 8 files changed, 160 insertions(+), 52 deletions(-)

diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index b3415af0e..ca3a68373 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -214,6 +214,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         moe_dp_outside: bool = True,
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
+        fp8_communication: bool = False,
     ) -> None:
         if overlap_communication or zero_stage == 2:
             overlap_communication = False
@@ -341,6 +342,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             parallel_output=parallel_output,
             make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             gradient_checkpoint_config=gradient_checkpoint_config,
+            fp8_communication=fp8_communication,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index ac422a4da..ba087a03b 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -6,6 +6,8 @@ from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup
 
+from colossalai.quantization.fp8 import all_to_all_single_fp8
+
 MOE_KERNEL = None
 
 
@@ -380,6 +382,7 @@ def _all_to_all(
     output_split_sizes: Optional[List[int]] = None,
     group=None,
     async_op: bool = False,
+    fp8_communication: bool = False,
 ):
     """
     Returns:
@@ -392,9 +395,14 @@ def _all_to_all(
     outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
     inputs = inputs.contiguous()
     outputs = outputs.contiguous()
-    handle = dist.all_to_all_single(
-        outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
-    )
+    if fp8_communication:
+        handle = all_to_all_single_fp8(
+            outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=False
+        )
+    else:
+        handle = dist.all_to_all_single(
+            outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
+        )
     return outputs, handle
 
 
@@ -407,6 +415,7 @@ class AllToAllUneven(torch.autograd.Function):
         output_split_sizes=None,
         group=None,
         overlap: bool = False,
+        fp8_communication: bool = False,
     ):
         """
         Returns:
@@ -416,7 +425,9 @@ class AllToAllUneven(torch.autograd.Function):
         ctx.input_split_sizes = input_split_sizes
         ctx.output_split_sizes = output_split_sizes
         ctx.group = group
-        return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
+        return _all_to_all(
+            inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication=fp8_communication
+        )
 
     @staticmethod
     def backward(ctx: Any, *grad_outputs):
@@ -426,6 +437,7 @@ class AllToAllUneven(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
 
 
@@ -435,8 +447,9 @@ def all_to_all_uneven(
     output_split_sizes: Optional[List[int]] = None,
     group=None,
     overlap: bool = False,
+    fp8_communication: bool = False,
 ):
     assert (
         inputs.requires_grad
     ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
-    return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
+    return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)
diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index cfbf1fcf7..f2bffa09f 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -27,16 +27,19 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
     fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
     fp8_max = torch.finfo(fp8_type).max
 
-    if per_channel_scale:
-        per_channel_max = inp.abs().max(dim=-1).values.float()
-        per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
-        scale = fp8_max / per_channel_max[:, None]
-        scale_inv = per_channel_max / fp8_max
+    if inp.numel() == 0:
+        return inp.to(fp8_type), torch.tensor([1.0], device=inp.device)
     else:
-        per_tensor_max = inp.abs().max().float()
-        per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
-        scale = fp8_max / per_tensor_max
-        scale_inv = 1.0 / scale
+        if per_channel_scale:
+            per_channel_max = inp.abs().max(dim=-1).values.float()
+            per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
+            scale = fp8_max / per_channel_max[:, None]
+            scale_inv = per_channel_max / fp8_max
+        else:
+            per_tensor_max = inp.abs().max().float()
+            per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
+            scale = fp8_max / per_tensor_max
+            scale_inv = 1.0 / scale
 
     ret = (scale * inp.float()).to(fp8_type)
     return ret, scale_inv
@@ -113,7 +116,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
     tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
     dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
     for i in range(world_size):
-        tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
+        tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i].to(input_device)
     out = torch.cat(tensor_list, dim=0)
     tensor.copy_(out[:input_size].view(input_shape).to(input_type))
 
diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py
index 9b77774aa..186063503 100644
--- a/colossalai/shardformer/layer/embedding.py
+++ b/colossalai/shardformer/layer/embedding.py
@@ -274,6 +274,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
         weight: Optional[nn.Parameter] = None,
         weight_initializer: Callable = init.normal_(),
         make_vocab_size_divisible_by: int = 64,
+        fp8_communication: bool = False,
         *args,
         **kwargs,
     ):
@@ -282,6 +283,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
         self.embed_args = args
         self.embed_kwargs = kwargs
         self.process_group = process_group
+        self.fp8_communication = fp8_communication
 
         tensor_parallel_size = dist.get_world_size(group=process_group)
         tensor_parallel_rank = dist.get_rank(group=process_group)
@@ -390,5 +392,5 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
         embedding_output = output_parallel.clone()
         embedding_output[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
-        output = reduce_forward(embedding_output, self.process_group)
+        output = reduce_forward(embedding_output, self.process_group, fp8_communication=self.fp8_communication)
         return output
diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index a84a30972..48a629705 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -24,6 +24,7 @@ from colossalai.moe._operation import (
     all_to_all_uneven,
 )
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.quantization.fp8 import all_reduce_fp8
 from colossalai.shardformer.layer._operation import (
     all_to_all_comm,
     gather_forward_split_backward,
@@ -61,7 +62,13 @@ class EPDeepseekMoE(nn.Module):
     def __init__(self):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
 
-    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
+    def setup_process_groups(
+        self,
+        tp_group: ProcessGroup,
+        moe_dp_group: ProcessGroup,
+        ep_group: ProcessGroup,
+        fp8_communication: bool = False,
+    ):
         assert tp_group is not None
         assert moe_dp_group is not None
         assert ep_group is not None
@@ -70,6 +77,7 @@ class EPDeepseekMoE(nn.Module):
         self.ep_rank = dist.get_rank(ep_group)
         self.num_experts = self.config.n_routed_experts
         assert self.num_experts % self.ep_size == 0
+        self.fp8_communication = fp8_communication
 
         self.ep_group = ep_group
         self.num_experts_per_ep = self.num_experts // self.ep_size
@@ -86,9 +94,15 @@ class EPDeepseekMoE(nn.Module):
         self.tp_group = tp_group
         if self.tp_group.size() > 1:
             for expert in held_experts:
-                expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group)
-                expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group)
-                expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group)
+                expert.gate_proj = Linear1D_Col.from_native_module(
+                    expert.gate_proj, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.up_proj = Linear1D_Col.from_native_module(
+                    expert.up_proj, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.down_proj = Linear1D_Row.from_native_module(
+                    expert.down_proj, self.tp_group, fp8_communication=self.fp8_communication
+                )
 
         for p in self.experts.parameters():
             set_moe_tensor_ep_group(p, ep_group)
@@ -106,7 +120,8 @@ class EPDeepseekMoE(nn.Module):
         if module.__class__.__name__ == "DeepseekMLP":
             return module
         module.__class__ = EPDeepseekMoE
-        module.setup_process_groups(tp_group, moe_dp_group, ep_group)
+        fp8_communication = kwargs.get("fp8_communication", False)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication=fp8_communication)
         return module
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -137,11 +152,21 @@ class EPDeepseekMoE(nn.Module):
         for i in range(1, self.ep_size):
             activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
         activate_experts = (activate_experts > 0).float()
-        dist.all_reduce(activate_experts, group=self.moe_dp_group)
+
+        if self.fp8_communication:
+            all_reduce_fp8(activate_experts, group=self.moe_dp_group)
+        else:
+            dist.all_reduce(activate_experts, group=self.moe_dp_group)
 
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
-        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
+        output_states, _ = all_to_all_uneven(
+            dispatch_states,
+            input_split_list,
+            output_split_list,
+            self.ep_group,
+            fp8_communication=self.fp8_communication,
+        )
 
         output_states = EPGradScalerIn.apply(output_states, self.ep_size)
         if output_states.size(0) > 0:
@@ -167,7 +192,9 @@ class EPDeepseekMoE(nn.Module):
                 output_states_list.append(split_states)
             output_states = torch.cat(output_states_list)
         output_states = EPGradScalerOut.apply(output_states, self.ep_size)
-        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
+        dispatch_states, _ = all_to_all_uneven(
+            output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
+        )
         recover_token_idx = torch.empty_like(flat_topk_token_idx)
         recover_token_idx[flat_topk_token_idx] = torch.arange(
             flat_topk_token_idx.size(0), device=flat_topk_token_idx.device
@@ -534,9 +561,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group)
-            key_states = all_to_all_comm(key_states, sp_group)
-            value_states = all_to_all_comm(value_states, sp_group)
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
 
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
@@ -595,7 +622,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (1, 8, 128)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)  # (1, 4, 256)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )  # (1, 4, 256)
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
@@ -685,9 +714,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
             )
 
         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+            )
         # embed positions
         hidden_states = inputs_embeds
 
@@ -731,9 +764,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
         hidden_states = self.norm(hidden_states)
 
         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+            )
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index d30ce5ea8..76b441fe4 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -53,7 +53,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
     def __init__(self, *args, **kwargs):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
 
-    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
+    def setup_process_groups(
+        self,
+        tp_group: ProcessGroup,
+        moe_dp_group: ProcessGroup,
+        ep_group: ProcessGroup,
+        fp8_communication: bool = False,
+    ):
         assert tp_group is not None
         assert moe_dp_group is not None
         assert ep_group is not None
@@ -62,6 +68,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         self.ep_size = dist.get_world_size(ep_group)
         self.ep_rank = dist.get_rank(ep_group)
         self.ep_group = ep_group
+        self.fp8_communication = fp8_communication
 
         if self.num_experts % self.ep_size != 0:
             raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
@@ -80,9 +87,15 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         self.tp_group = tp_group
         if self.tp_group.size() > 1:
             for expert in held_experts:
-                expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group)
-                expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group)
-                expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group)
+                expert.w1 = Linear1D_Col.from_native_module(
+                    expert.w1, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.w3 = Linear1D_Col.from_native_module(
+                    expert.w3, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.w2 = Linear1D_Row.from_native_module(
+                    expert.w2, self.tp_group, fp8_communication=self.fp8_communication
+                )
 
         for p in self.experts.parameters():
             set_moe_tensor_ep_group(p, ep_group)
@@ -99,7 +112,8 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         # TODO: better init
         LazyInitContext.materialize(module)
         module.__class__ = EPMixtralSparseMoeBlock
-        module.setup_process_groups(tp_group, moe_dp_group, ep_group)
+        fp8_communication = kwargs.get("fp8_communication", False)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication)
         return module
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -120,6 +134,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
 
         input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
         output_split_sizes = torch.zeros_like(input_split_sizes)
+
         dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
 
         with torch.no_grad():
@@ -132,7 +147,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
 
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
-        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
+        output_states, _ = all_to_all_uneven(
+            dispatch_states,
+            input_split_list,
+            output_split_list,
+            self.ep_group,
+            fp8_communication=self.fp8_communication,
+        )
         # compute expert output
         output_states = EPGradScalerIn.apply(output_states, self.ep_size)
         if output_states.size(0) > 0:
@@ -162,7 +183,9 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
             output_states = torch.cat(output_states_list)
 
         output_states = EPGradScalerOut.apply(output_states, self.ep_size)
-        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
+        dispatch_states, _ = all_to_all_uneven(
+            output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
+        )
 
         recover_experts_idx = torch.empty_like(selected_experts_idx)
         recover_experts_idx[selected_experts_idx] = torch.arange(
@@ -566,9 +589,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
 
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group)
-            key_states = all_to_all_comm(key_states, sp_group)
-            value_states = all_to_all_comm(value_states, sp_group)
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
 
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -780,9 +803,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
             )
 
         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+            )
         hidden_states = inputs_embeds
 
         # decoder layers
@@ -831,9 +858,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
         hidden_states = self.norm(hidden_states)
 
         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+            )
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index 605f69c4a..0b8a602d1 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -118,18 +118,22 @@ class DeepseekPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
                         target_module=Linear1D_Col,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.k_proj",
                         target_module=Linear1D_Col,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.v_proj",
                         target_module=Linear1D_Col,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.o_proj",
                         target_module=Linear1D_Row,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                 ],
             )
@@ -138,7 +142,10 @@ class DeepseekPolicy(Policy):
                 description=SubModuleReplacementDescription(
                     suffix="embed_tokens",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs={
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 policy=policy,
                 target_key="DeepseekModel",
@@ -155,6 +162,7 @@ class DeepseekPolicy(Policy):
                             "ep_group": self.shard_config.ep_group,
                             "tp_group": self.shard_config.tensor_parallel_process_group,
                             "moe_dp_group": self.shard_config.moe_dp_group,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     )
                 ],
@@ -305,7 +313,7 @@ class DeepseekForCausalLMPolicy(DeepseekPolicy):
                 SubModuleReplacementDescription(
                     suffix="lm_head",
                     target_module=Linear1D_Col,
-                    kwargs=dict(gather_output=True),
+                    kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
                 )
             ]
         )
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 10df143c9..550b7f236 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -114,21 +114,27 @@ class MixtralPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
                         target_module=Linear1D_Col,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.k_proj",
                         target_module=Linear1D_Col,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.v_proj",
                         target_module=Linear1D_Col,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.o_proj",
                         target_module=Linear1D_Row,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                     SubModuleReplacementDescription(  # or replicate?
- suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True} + suffix="block_sparse_moe.gate", + target_module=Linear1D_Col, + kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication}, ), ], ) @@ -138,7 +144,10 @@ class MixtralPolicy(Policy): description=SubModuleReplacementDescription( suffix="embed_tokens", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs={ + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + }, ), policy=policy, target_key=MixtralModel, @@ -155,6 +164,7 @@ class MixtralPolicy(Policy): "ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, + "fp8_communication": self.shard_config.fp8_communication, }, ) ], @@ -282,7 +292,7 @@ class MixtralForCausalLMPolicy(MixtralPolicy): SubModuleReplacementDescription( suffix="lm_head", target_module=Linear1D_Col, - kwargs=dict(gather_output=True), + kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), ) ] ) @@ -336,7 +346,9 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy): MixtralForSequenceClassification: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + suffix="score", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), ) ] )
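
Usage note (editor's addition, not part of the applied patch): a minimal sketch of how the new `fp8_communication` flag surfaces at the two public entry points this diff touches, the MoeHybridParallelPlugin constructor and all_to_all_uneven. The tp/pp/ep sizes, the torchrun launch, and the split sizes are illustrative assumptions, not taken from the diff.

    import torch
    import torch.distributed as dist

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
    from colossalai.moe._operation import all_to_all_uneven

    colossalai.launch_from_torch()  # e.g. run via `torchrun --nproc_per_node=2 demo.py`

    # Plugin-level switch: the flag is forwarded into the shard config (see the
    # hunks above), so TP linears, embeddings, and the MoE all-to-all quantize
    # payloads to fp8 (e4m3 by default) before each collective.
    plugin = MoeHybridParallelPlugin(
        tp_size=1,  # assumed sizes; only `fp8_communication` is new in this patch
        pp_size=1,
        ep_size=2,
        zero_stage=1,
        fp8_communication=True,
    )
    booster = Booster(plugin=plugin)

    # Op-level switch: all_to_all_uneven now accepts the same flag directly.
    # Inputs must require grad, per the assert shown in the diff.
    inputs = torch.randn(4, 8, device="cuda", requires_grad=True)
    outputs, _ = all_to_all_uneven(
        inputs, [2, 2], [2, 2], group=dist.group.WORLD, fp8_communication=True
    )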
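
A second small sketch (also an editor's addition) of the cast_to_fp8 change in colossalai/quantization/fp8.py: the new early-return means empty tensors no longer hit the max()-based scale computation, which would fail on zero elements. Assumes a PyTorch build with float8 dtypes.

    import torch
    from colossalai.quantization.fp8 import cast_to_fp8

    # Normal path: per-tensor scaling into e4m3, with an inverse scale for recovery.
    x = torch.randn(4, 8)
    x_fp8, scale_inv = cast_to_fp8(x, fp8_format="e4m3")
    x_restored = x_fp8.float() * scale_inv  # approximate round-trip

    # New path from this patch: an empty input returns an empty fp8 tensor and a
    # unit inverse scale instead of reducing over zero elements.
    empty = torch.empty(0)
    e_fp8, e_scale_inv = cast_to_fp8(empty)
    assert e_fp8.numel() == 0 and float(e_scale_inv) == 1.0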