
[fp8] MoE support for fp8 communication (#5977)

* fix

* support moe fp8

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

fix

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Branch: pull/5998/head
Author: flybird11111 (committed via GitHub), 3 months ago
Commit: f1a3a326c4
8 changed files:

  1. colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (2 lines changed)
  2. colossalai/moe/_operation.py (23 lines changed)
  3. colossalai/quantization/fp8.py (23 lines changed)
  4. colossalai/shardformer/layer/embedding.py (4 lines changed)
  5. colossalai/shardformer/modeling/deepseek.py (69 lines changed)
  6. colossalai/shardformer/modeling/mixtral.py (59 lines changed)
  7. colossalai/shardformer/policies/deepseek.py (12 lines changed)
  8. colossalai/shardformer/policies/mixtral.py (20 lines changed)

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (2 lines changed)

@@ -214,6 +214,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
moe_dp_outside: bool = True,
overlap_p2p: bool = True,
overlap_allgather: bool = False,
fp8_communication: bool = False,
) -> None:
if overlap_communication or zero_stage == 2:
overlap_communication = False
@@ -341,6 +342,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
fp8_communication=fp8_communication,
)
self.amp_config = dict(
initial_scale=initial_scale,
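
A minimal usage sketch of where the new flag surfaces to users. Everything here except fp8_communication (the launch call, the parallel sizes, the precision) is an illustrative assumption, not something this diff prescribes.

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

# Assumes a torchrun-launched distributed job; sizes and precision are placeholders.
colossalai.launch_from_torch()
plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,
    precision="bf16",
    fp8_communication=True,  # new flag: forwarded into the plugin's shard config
)
booster = Booster(plugin=plugin)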

colossalai/moe/_operation.py (23 lines changed)

@@ -6,6 +6,8 @@ from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup
from colossalai.quantization.fp8 import all_to_all_single_fp8
MOE_KERNEL = None
@@ -380,6 +382,7 @@ def _all_to_all(
output_split_sizes: Optional[List[int]] = None,
group=None,
async_op: bool = False,
fp8_communication: bool = False,
):
"""
Returns:
@@ -392,9 +395,14 @@
outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
inputs = inputs.contiguous()
outputs = outputs.contiguous()
handle = dist.all_to_all_single(
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
)
if fp8_communication:
handle = all_to_all_single_fp8(
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=False
)
else:
handle = dist.all_to_all_single(
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
)
return outputs, handle
@@ -407,6 +415,7 @@ class AllToAllUneven(torch.autograd.Function):
output_split_sizes=None,
group=None,
overlap: bool = False,
fp8_communication: bool = False,
):
"""
Returns:
@@ -416,7 +425,9 @@ class AllToAllUneven(torch.autograd.Function):
ctx.input_split_sizes = input_split_sizes
ctx.output_split_sizes = output_split_sizes
ctx.group = group
return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
return _all_to_all(
inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication=fp8_communication
)
@staticmethod
def backward(ctx: Any, *grad_outputs):
@@ -426,6 +437,7 @@
None,
None,
None,
None,
)
@@ -435,8 +447,9 @@ def all_to_all_uneven(
output_split_sizes: Optional[List[int]] = None,
group=None,
overlap: bool = False,
fp8_communication: bool = False,
):
assert (
inputs.requires_grad
), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)
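
With the flag set, _all_to_all routes through all_to_all_single_fp8 instead of dist.all_to_all_single (note the fp8 branch is called with async_op=False, so the returned handle is not an async work handle on that path). Below is a rough standalone sketch of what an fp8 all-to-all does conceptually, assuming even splits and a per-tensor scale; it is not the library implementation.

import torch
import torch.distributed as dist

def fp8_all_to_all_sketch(output: torch.Tensor, input: torch.Tensor, group=None, fp8_format: str = "e4m3") -> None:
    """Quantize -> exchange raw bytes -> dequantize. Illustrative only."""
    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_type).max
    world_size = dist.get_world_size(group)

    # 1. Per-tensor scale, guarding against an all-zero input (mirroring cast_to_fp8).
    amax = input.abs().max().float().reshape(1)
    amax = torch.where(amax > 0, amax, torch.ones_like(amax))
    fp8_input = (input.float() * (fp8_max / amax)).to(fp8_type)

    # 2. NCCL has no fp8 dtype, so the payload is exchanged as raw uint8 bytes.
    fp8_output = torch.empty_like(output, dtype=fp8_type)
    dist.all_to_all_single(fp8_output.view(torch.uint8), fp8_input.view(torch.uint8), group=group)

    # 3. Each receiver needs every sender's scale to dequantize its chunk.
    amax_list = [torch.empty_like(amax) for _ in range(world_size)]
    dist.all_gather(amax_list, amax, group=group)

    # 4. Dequantize chunk by chunk with the matching sender's inverse scale.
    chunks = fp8_output.view(torch.uint8).chunk(world_size, dim=0)
    deq = [c.view(fp8_type).to(output.dtype) * (amax_list[i] / fp8_max) for i, c in enumerate(chunks)]
    output.copy_(torch.cat(deq, dim=0))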

colossalai/quantization/fp8.py (23 lines changed)

@@ -27,16 +27,19 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
fp8_max = torch.finfo(fp8_type).max
if per_channel_scale:
per_channel_max = inp.abs().max(dim=-1).values.float()
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
scale = fp8_max / per_channel_max[:, None]
scale_inv = per_channel_max / fp8_max
if inp.numel() == 0:
return inp.to(fp8_type), torch.tensor([1.0], device=inp.device)
else:
per_tensor_max = inp.abs().max().float()
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
scale = fp8_max / per_tensor_max
scale_inv = 1.0 / scale
if per_channel_scale:
per_channel_max = inp.abs().max(dim=-1).values.float()
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
scale = fp8_max / per_channel_max[:, None]
scale_inv = per_channel_max / fp8_max
else:
per_tensor_max = inp.abs().max().float()
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
scale = fp8_max / per_tensor_max
scale_inv = 1.0 / scale
ret = (scale * inp.float()).to(fp8_type)
return ret, scale_inv
@@ -113,7 +116,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
for i in range(world_size):
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i].to(input_device)
out = torch.cat(tensor_list, dim=0)
tensor.copy_(out[:input_size].view(input_shape).to(input_type))
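
A small usage sketch of the updated cast_to_fp8 (assuming a CUDA device). e4m3 tops out at 448, so an input whose absolute maximum is 2.0 gets scale = 448 / 2.0 = 224 and scale_inv = 1 / 224; the new guard also returns an empty tensor untouched with a unit scale instead of calling max() on zero elements.

import torch
from colossalai.quantization.fp8 import cast_to_fp8

x = torch.tensor([[0.5, -2.0], [1.0, 0.25]], device="cuda")  # assumes a CUDA device
ret, scale_inv = cast_to_fp8(x, fp8_format="e4m3")
recovered = ret.to(torch.float32) * scale_inv  # approximate round-trip

# New early return: empty inputs come back unchanged with a unit scale.
empty_ret, unit_scale = cast_to_fp8(torch.empty(0, device="cuda"))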

colossalai/shardformer/layer/embedding.py (4 lines changed)

@@ -274,6 +274,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(),
make_vocab_size_divisible_by: int = 64,
fp8_communication: bool = False,
*args,
**kwargs,
):
@@ -282,6 +283,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
self.embed_args = args
self.embed_kwargs = kwargs
self.process_group = process_group
self.fp8_communication = fp8_communication
tensor_parallel_size = dist.get_world_size(group=process_group)
tensor_parallel_rank = dist.get_rank(group=process_group)
@@ -390,5 +392,5 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
embedding_output = output_parallel.clone()
embedding_output[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_forward(embedding_output, self.process_group)
output = reduce_forward(embedding_output, self.process_group, fp8_communication=self.fp8_communication)
return output
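
A hedged sketch of wiring the flag into this layer by hand (the policies below do it automatically). The process group, the sizes, and the from_native_module keyword names are assumptions inferred from the policy kwargs in this commit.

import torch.distributed as dist
import torch.nn as nn
from colossalai.shardformer.layer import VocabParallelEmbedding1D

tp_group = dist.group.WORLD  # stand-in for a real tensor-parallel group
embed_tokens = nn.Embedding(32000, 4096)  # placeholder vocab/hidden sizes
sharded = VocabParallelEmbedding1D.from_native_module(
    embed_tokens,
    process_group=tp_group,
    make_vocab_size_divisible_by=64,
    fp8_communication=True,  # forward() now hands this to reduce_forward
)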

colossalai/shardformer/modeling/deepseek.py (69 lines changed)

@@ -24,6 +24,7 @@ from colossalai.moe._operation import (
all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization.fp8 import all_reduce_fp8
from colossalai.shardformer.layer._operation import (
all_to_all_comm,
gather_forward_split_backward,
@@ -61,7 +62,13 @@ class EPDeepseekMoE(nn.Module):
def __init__(self):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
def setup_process_groups(
self,
tp_group: ProcessGroup,
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
fp8_communication: bool = False,
):
assert tp_group is not None
assert moe_dp_group is not None
assert ep_group is not None
@@ -70,6 +77,7 @@ class EPDeepseekMoE(nn.Module):
self.ep_rank = dist.get_rank(ep_group)
self.num_experts = self.config.n_routed_experts
assert self.num_experts % self.ep_size == 0
self.fp8_communication = fp8_communication
self.ep_group = ep_group
self.num_experts_per_ep = self.num_experts // self.ep_size
@@ -86,9 +94,15 @@ class EPDeepseekMoE(nn.Module):
self.tp_group = tp_group
if self.tp_group.size() > 1:
for expert in held_experts:
expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group)
expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group)
expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group)
expert.gate_proj = Linear1D_Col.from_native_module(
expert.gate_proj, self.tp_group, fp8_communication=self.fp8_communication
)
expert.up_proj = Linear1D_Col.from_native_module(
expert.up_proj, self.tp_group, fp8_communication=self.fp8_communication
)
expert.down_proj = Linear1D_Row.from_native_module(
expert.down_proj, self.tp_group, fp8_communication=self.fp8_communication
)
for p in self.experts.parameters():
set_moe_tensor_ep_group(p, ep_group)
@@ -106,7 +120,8 @@ class EPDeepseekMoE(nn.Module):
if module.__class__.__name__ == "DeepseekMLP":
return module
module.__class__ = EPDeepseekMoE
module.setup_process_groups(tp_group, moe_dp_group, ep_group)
fp8_communication = kwargs.get("fp8_communication", False)
module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication=fp8_communication)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -137,11 +152,21 @@ class EPDeepseekMoE(nn.Module):
for i in range(1, self.ep_size):
activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
activate_experts = (activate_experts > 0).float()
dist.all_reduce(activate_experts, group=self.moe_dp_group)
if self.fp8_communication:
all_reduce_fp8(activate_experts, group=self.moe_dp_group)
else:
dist.all_reduce(activate_experts, group=self.moe_dp_group)
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
output_states, _ = all_to_all_uneven(
dispatch_states,
input_split_list,
output_split_list,
self.ep_group,
fp8_communication=self.fp8_communication,
)
output_states = EPGradScalerIn.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
@@ -167,7 +192,9 @@ class EPDeepseekMoE(nn.Module):
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = EPGradScalerOut.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
dispatch_states, _ = all_to_all_uneven(
output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
)
recover_token_idx = torch.empty_like(flat_topk_token_idx)
recover_token_idx[flat_topk_token_idx] = torch.arange(
flat_topk_token_idx.size(0), device=flat_topk_token_idx.device
@@ -534,9 +561,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
# sp: all-to-all comminucation when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
@@ -595,7 +622,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
# sp: all-to-all comminucation when introducing sequence parallel
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256)
attn_output = all_to_all_comm(
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
) # (1, 4, 256)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -685,9 +714,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
)
# embed positions
hidden_states = inputs_embeds
@@ -731,9 +764,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer
if output_hidden_states:
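
The split-size bookkeeping feeding these all_to_all_uneven calls is easiest to follow with concrete numbers; the toy routing result below is made up, and only the final comment involves the new flag.

import torch

num_experts, ep_size = 8, 2
num_experts_per_ep = num_experts // ep_size  # 4 experts hosted per EP rank

# Toy routing result: the expert chosen for each routed token on this rank.
selected_experts = torch.tensor([0, 0, 3, 5, 5, 5, 7, 2])
input_split_sizes = selected_experts.bincount(minlength=num_experts)
# tensor([2, 0, 1, 1, 0, 3, 0, 1])  -> token count destined for each expert

# Collapse per-expert counts into per-EP-rank counts:
# experts 0-3 live on EP rank 0, experts 4-7 on EP rank 1.
input_split_list = input_split_sizes.view(ep_size, num_experts_per_ep).sum(dim=-1).tolist()
# [4, 4]  -> send 4 tokens to EP rank 0 and 4 tokens to EP rank 1
# With fp8_communication=True, these lists are passed to all_to_all_uneven, which
# now ships the token payload through all_to_all_single_fp8.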

colossalai/shardformer/modeling/mixtral.py (59 lines changed)

@@ -53,7 +53,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, *args, **kwargs):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
def setup_process_groups(
self,
tp_group: ProcessGroup,
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
fp8_communication: bool = False,
):
assert tp_group is not None
assert moe_dp_group is not None
assert ep_group is not None
@@ -62,6 +68,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
self.ep_size = dist.get_world_size(ep_group)
self.ep_rank = dist.get_rank(ep_group)
self.ep_group = ep_group
self.fp8_communication = fp8_communication
if self.num_experts % self.ep_size != 0:
raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
@@ -80,9 +87,15 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
self.tp_group = tp_group
if self.tp_group.size() > 1:
for expert in held_experts:
expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group)
expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group)
expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group)
expert.w1 = Linear1D_Col.from_native_module(
expert.w1, self.tp_group, fp8_communication=self.fp8_communication
)
expert.w3 = Linear1D_Col.from_native_module(
expert.w3, self.tp_group, fp8_communication=self.fp8_communication
)
expert.w2 = Linear1D_Row.from_native_module(
expert.w2, self.tp_group, fp8_communication=self.fp8_communication
)
for p in self.experts.parameters():
set_moe_tensor_ep_group(p, ep_group)
@@ -99,7 +112,8 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
# TODO: better init
LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
module.setup_process_groups(tp_group, moe_dp_group, ep_group)
fp8_communication = kwargs.get("fp8_communication", False)
module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -120,6 +134,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
with torch.no_grad():
@@ -132,7 +147,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
output_states, _ = all_to_all_uneven(
dispatch_states,
input_split_list,
output_split_list,
self.ep_group,
fp8_communication=self.fp8_communication,
)
# compute expert output
output_states = EPGradScalerIn.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
@@ -162,7 +183,9 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
output_states = torch.cat(output_states_list)
output_states = EPGradScalerOut.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
dispatch_states, _ = all_to_all_uneven(
output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
)
recover_experts_idx = torch.empty_like(selected_experts_idx)
recover_experts_idx[selected_experts_idx] = torch.arange(
@@ -566,9 +589,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
# sp: all-to-all comminucation when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -780,9 +803,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
)
hidden_states = inputs_embeds
# decoder layers
@@ -831,9 +858,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer
if output_hidden_states:
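
The Mixtral path mirrors the Deepseek one: the flag enters through from_native_module's kwargs and is handed to setup_process_groups. A hedged sketch of that entry point, assuming an initialized distributed job, a tiny illustrative config, and a world size that divides the expert count; the default group stands in for all three process groups.

import torch.distributed as dist
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock

group = dist.group.WORLD  # placeholder for the tp/moe_dp/ep groups
config = MixtralConfig(hidden_size=64, intermediate_size=128, num_local_experts=4, num_experts_per_tok=2)
native_block = MixtralSparseMoeBlock(config)

ep_block = EPMixtralSparseMoeBlock.from_native_module(
    native_block,
    tp_group=group,
    moe_dp_group=group,
    ep_group=group,
    fp8_communication=True,  # read via kwargs.get("fp8_communication", False)
)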

colossalai/shardformer/policies/deepseek.py (12 lines changed)

@@ -118,18 +118,22 @@ class DeepseekPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
],
)
@@ -138,7 +142,10 @@ class DeepseekPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs={
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
),
policy=policy,
target_key="DeepseekModel",
@@ -155,6 +162,7 @@ class DeepseekPolicy(Policy):
"ep_group": self.shard_config.ep_group,
"tp_group": self.shard_config.tensor_parallel_process_group,
"moe_dp_group": self.shard_config.moe_dp_group,
"fp8_communication": self.shard_config.fp8_communication,
},
)
],
@@ -305,7 +313,7 @@ class DeepseekForCausalLMPolicy(DeepseekPolicy):
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
)
]
)
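
All of these kwargs are driven by a single field on the shard config, so the user-facing surface stays one flag. A minimal sketch assuming ShardConfig's usual constructor; the MoE-specific groups (ep_group, moe_dp_group) that DeepseekPolicy also reads are normally wired up by MoeHybridParallelPlugin and are omitted here.

from colossalai.shardformer import ShardConfig

# Illustrative only: the plugin builds something like this internally, and the
# policy reads shard_config.fp8_communication when emitting the kwargs above.
shard_config = ShardConfig(
    enable_tensor_parallelism=False,  # placeholder; a real run would pass a TP group
    fp8_communication=True,
)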

colossalai/shardformer/policies/mixtral.py (20 lines changed)

@@ -114,21 +114,27 @@ class MixtralPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription( # or replicate?
suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True}
suffix="block_sparse_moe.gate",
target_module=Linear1D_Col,
kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication},
),
],
)
@@ -138,7 +144,10 @@ class MixtralPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs={
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
),
policy=policy,
target_key=MixtralModel,
@@ -155,6 +164,7 @@ class MixtralPolicy(Policy):
"ep_group": self.shard_config.ep_group,
"tp_group": self.shard_config.tensor_parallel_process_group,
"moe_dp_group": self.shard_config.moe_dp_group,
"fp8_communication": self.shard_config.fp8_communication,
},
)
],
@@ -282,7 +292,7 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
)
]
)
@@ -336,7 +346,9 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
MixtralForSequenceClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
suffix="score",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
)
]
)
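
The lm_head and score replacements described by these kwargs can also be written out by hand. A hedged sketch, assuming an initialized distributed job whose world size divides the output dimension; the layer sizes are placeholders.

import torch.distributed as dist
import torch.nn as nn
from colossalai.shardformer.layer import Linear1D_Col

tp_group = dist.group.WORLD  # stand-in for the tensor-parallel group
lm_head = nn.Linear(4096, 32000, bias=False)
sharded_head = Linear1D_Col.from_native_module(
    lm_head,
    process_group=tp_group,
    gather_output=True,  # keep full-vocab logits on every rank
    fp8_communication=True,  # gather the output over fp8, as the policy now requests
)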
