mirror of https://github.com/hpcaitech/ColossalAI
[fp8]Moe support fp8 communication (#5977)
* fix
* support moe fp8
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
* fix
* fix fix fi
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent e4aadeee20
commit f1a3a326c4
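This commit threads a single `fp8_communication` flag from `MoeHybridParallelPlugin` down through the shard config, the Deepseek/Mixtral policies, and the MoE collectives (`all_to_all_uneven`, `all_to_all_comm`, `all_reduce_fp8`, `reduce_forward`). A rough enablement sketch follows; only the `fp8_communication` argument comes from this diff, while the remaining constructor arguments and the `Booster` wrapper are assumptions about the surrounding ColossalAI API rather than part of the change.

    # Hedged sketch: enabling fp8 communication for a MoE model via the plugin.
    from colossalai.booster import Booster
    from colossalai.booster.plugin import MoeHybridParallelPlugin  # assumed import path

    plugin = MoeHybridParallelPlugin(
        tp_size=1,               # assumed pre-existing argument
        pp_size=1,               # assumed pre-existing argument
        ep_size=2,               # assumed pre-existing argument
        zero_stage=1,            # assumed pre-existing argument
        fp8_communication=True,  # new flag introduced by this commit
    )
    booster = Booster(plugin=plugin)  # model/optimizer boosting omitted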
@@ -214,6 +214,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         moe_dp_outside: bool = True,
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
+        fp8_communication: bool = False,
     ) -> None:
         if overlap_communication or zero_stage == 2:
             overlap_communication = False
@@ -341,6 +342,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             parallel_output=parallel_output,
             make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             gradient_checkpoint_config=gradient_checkpoint_config,
+            fp8_communication=fp8_communication,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
@@ -6,6 +6,8 @@ from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup
 
+from colossalai.quantization.fp8 import all_to_all_single_fp8
+
 MOE_KERNEL = None
 
 
@@ -380,6 +382,7 @@ def _all_to_all(
     output_split_sizes: Optional[List[int]] = None,
     group=None,
     async_op: bool = False,
+    fp8_communication: bool = False,
 ):
     """
     Returns:
@@ -392,9 +395,14 @@ def _all_to_all(
     outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
     inputs = inputs.contiguous()
     outputs = outputs.contiguous()
-    handle = dist.all_to_all_single(
-        outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
-    )
+    if fp8_communication:
+        handle = all_to_all_single_fp8(
+            outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=False
+        )
+    else:
+        handle = dist.all_to_all_single(
+            outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
+        )
     return outputs, handle
 
 
@@ -407,6 +415,7 @@ class AllToAllUneven(torch.autograd.Function):
         output_split_sizes=None,
         group=None,
         overlap: bool = False,
+        fp8_communication: bool = False,
     ):
         """
         Returns:
@@ -416,7 +425,9 @@ class AllToAllUneven(torch.autograd.Function):
         ctx.input_split_sizes = input_split_sizes
        ctx.output_split_sizes = output_split_sizes
         ctx.group = group
-        return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
+        return _all_to_all(
+            inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication=fp8_communication
+        )
 
     @staticmethod
     def backward(ctx: Any, *grad_outputs):
@@ -426,6 +437,7 @@ class AllToAllUneven(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
 
 
@@ -435,8 +447,9 @@ def all_to_all_uneven(
     output_split_sizes: Optional[List[int]] = None,
     group=None,
     overlap: bool = False,
+    fp8_communication: bool = False,
 ):
     assert (
         inputs.requires_grad
     ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
-    return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
+    return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)
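With the new keyword, `_all_to_all` routes through `all_to_all_single_fp8` when `fp8_communication=True` (note the fp8 path currently forces `async_op=False`), and `AllToAllUneven` / `all_to_all_uneven` simply forward the flag. A minimal sketch of the call pattern, assuming `torch.distributed` is already initialized and `ep_group` is a valid process group:

    import torch
    from colossalai.moe._operation import all_to_all_uneven

    def dispatch_tokens(tokens: torch.Tensor, input_split, output_split, ep_group, use_fp8: bool = True):
        # all_to_all_uneven asserts that its input requires grad so the custom backward runs.
        tokens = tokens.requires_grad_()
        outputs, handle = all_to_all_uneven(
            tokens, input_split, output_split, ep_group, overlap=False, fp8_communication=use_fp8
        )
        return outputs  # `handle` is only meaningful on the async (overlap) path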
@@ -27,16 +27,19 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
     fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
     fp8_max = torch.finfo(fp8_type).max
 
+    if inp.numel() == 0:
+        return inp.to(fp8_type), torch.tensor([1.0], device=inp.device)
+
     if per_channel_scale:
         per_channel_max = inp.abs().max(dim=-1).values.float()
         per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
         scale = fp8_max / per_channel_max[:, None]
         scale_inv = per_channel_max / fp8_max
     else:
         per_tensor_max = inp.abs().max().float()
         per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
         scale = fp8_max / per_tensor_max
         scale_inv = 1.0 / scale
 
     ret = (scale * inp.float()).to(fp8_type)
     return ret, scale_inv
@@ -113,7 +116,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
     tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
     dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
     for i in range(world_size):
-        tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
+        tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i].to(input_device)
     out = torch.cat(tensor_list, dim=0)
     tensor.copy_(out[:input_size].view(input_shape).to(input_type))
 
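For reference, the per-tensor branch of `cast_to_fp8` shown above amounts to the following standalone round trip. This is a re-derivation of the logic in the hunk, not a call into ColossalAI, and it assumes a PyTorch build that provides `torch.float8_e4m3fn`:

    import torch

    def cast_roundtrip(x: torch.Tensor) -> torch.Tensor:
        fp8_type = torch.float8_e4m3fn
        fp8_max = torch.finfo(fp8_type).max
        if x.numel() == 0:  # mirrors the empty-tensor guard added in this commit
            return x
        amax = x.abs().max().float()
        amax = torch.where(amax > 0, amax, torch.tensor(1.0, device=x.device))
        scale = fp8_max / amax        # quantization scale
        scale_inv = 1.0 / scale       # kept alongside the fp8 payload for dequantization
        x_fp8 = (scale * x.float()).to(fp8_type)
        return x_fp8.float() * scale_inv  # dequantized approximation of x

    x = torch.randn(4, 8)
    print((x - cast_roundtrip(x)).abs().max())  # small quantization error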
@@ -274,6 +274,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
         weight: Optional[nn.Parameter] = None,
         weight_initializer: Callable = init.normal_(),
         make_vocab_size_divisible_by: int = 64,
+        fp8_communication: bool = False,
         *args,
         **kwargs,
     ):
@@ -282,6 +283,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
         self.embed_args = args
         self.embed_kwargs = kwargs
         self.process_group = process_group
+        self.fp8_communication = fp8_communication
 
         tensor_parallel_size = dist.get_world_size(group=process_group)
         tensor_parallel_rank = dist.get_rank(group=process_group)
@@ -390,5 +392,5 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
         embedding_output = output_parallel.clone()
         embedding_output[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
-        output = reduce_forward(embedding_output, self.process_group)
+        output = reduce_forward(embedding_output, self.process_group, fp8_communication=self.fp8_communication)
         return output
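The embedding change mirrors the rest of the stack: the constructor stores the flag and the forward pass hands it to `reduce_forward`. In a policy-driven setup the flag arrives through the replacement kwargs; a hedged sketch, assuming the usual `from_native_module(module, process_group, **kwargs)` entry point used by the shardformer policies and an initialized tensor-parallel group:

    import torch.nn as nn
    from colossalai.shardformer.layer import VocabParallelEmbedding1D

    def shard_embedding(embed: nn.Embedding, tp_group, use_fp8: bool):
        # make_vocab_size_divisible_by and fp8_communication are exactly the kwargs
        # the Deepseek/Mixtral policies forward in this commit.
        return VocabParallelEmbedding1D.from_native_module(
            embed, tp_group, make_vocab_size_divisible_by=64, fp8_communication=use_fp8
        )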
@@ -24,6 +24,7 @@ from colossalai.moe._operation import (
     all_to_all_uneven,
 )
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.quantization.fp8 import all_reduce_fp8
 from colossalai.shardformer.layer._operation import (
     all_to_all_comm,
     gather_forward_split_backward,
@@ -61,7 +62,13 @@ class EPDeepseekMoE(nn.Module):
     def __init__(self):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
 
-    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
+    def setup_process_groups(
+        self,
+        tp_group: ProcessGroup,
+        moe_dp_group: ProcessGroup,
+        ep_group: ProcessGroup,
+        fp8_communication: bool = False,
+    ):
         assert tp_group is not None
         assert moe_dp_group is not None
         assert ep_group is not None
@@ -70,6 +77,7 @@ class EPDeepseekMoE(nn.Module):
         self.ep_rank = dist.get_rank(ep_group)
         self.num_experts = self.config.n_routed_experts
         assert self.num_experts % self.ep_size == 0
+        self.fp8_communication = fp8_communication
 
         self.ep_group = ep_group
         self.num_experts_per_ep = self.num_experts // self.ep_size
@@ -86,9 +94,15 @@ class EPDeepseekMoE(nn.Module):
         self.tp_group = tp_group
         if self.tp_group.size() > 1:
             for expert in held_experts:
-                expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group)
-                expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group)
-                expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group)
+                expert.gate_proj = Linear1D_Col.from_native_module(
+                    expert.gate_proj, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.up_proj = Linear1D_Col.from_native_module(
+                    expert.up_proj, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.down_proj = Linear1D_Row.from_native_module(
+                    expert.down_proj, self.tp_group, fp8_communication=self.fp8_communication
+                )
 
         for p in self.experts.parameters():
             set_moe_tensor_ep_group(p, ep_group)
@@ -106,7 +120,8 @@ class EPDeepseekMoE(nn.Module):
         if module.__class__.__name__ == "DeepseekMLP":
             return module
         module.__class__ = EPDeepseekMoE
-        module.setup_process_groups(tp_group, moe_dp_group, ep_group)
+        fp8_communication = kwargs.get("fp8_communication", False)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication=fp8_communication)
         return module
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -137,11 +152,21 @@ class EPDeepseekMoE(nn.Module):
         for i in range(1, self.ep_size):
             activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
         activate_experts = (activate_experts > 0).float()
-        dist.all_reduce(activate_experts, group=self.moe_dp_group)
+
+        if self.fp8_communication:
+            all_reduce_fp8(activate_experts, group=self.moe_dp_group)
+        else:
+            dist.all_reduce(activate_experts, group=self.moe_dp_group)
 
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
-        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
+        output_states, _ = all_to_all_uneven(
+            dispatch_states,
+            input_split_list,
+            output_split_list,
+            self.ep_group,
+            fp8_communication=self.fp8_communication,
+        )
         output_states = EPGradScalerIn.apply(output_states, self.ep_size)
 
         if output_states.size(0) > 0:
@@ -167,7 +192,9 @@ class EPDeepseekMoE(nn.Module):
                 output_states_list.append(split_states)
             output_states = torch.cat(output_states_list)
         output_states = EPGradScalerOut.apply(output_states, self.ep_size)
-        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
+        dispatch_states, _ = all_to_all_uneven(
+            output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
+        )
         recover_token_idx = torch.empty_like(flat_topk_token_idx)
         recover_token_idx[flat_topk_token_idx] = torch.arange(
             flat_topk_token_idx.size(0), device=flat_topk_token_idx.device
@@ -534,9 +561,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
 
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group)
-            key_states = all_to_all_comm(key_states, sp_group)
-            value_states = all_to_all_comm(value_states, sp_group)
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
@@ -595,7 +622,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (1, 8, 128)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)  # (1, 4, 256)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )  # (1, 4, 256)
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
@@ -685,9 +714,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
         )
 
         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+            )
         # embed positions
         hidden_states = inputs_embeds
 
@@ -731,9 +764,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
         hidden_states = self.norm(hidden_states)
 
         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+            )
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
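Both MoE blocks now follow the same communication pattern: count tokens per expert, exchange the counts, all-to-all the token states with the fp8 flag, run the local experts, and all-to-all the results back. A condensed sketch of that flow; names such as `run_local_experts` are placeholders rather than ColossalAI APIs, and initialized expert-parallel and MoE data-parallel groups are assumed:

    import torch
    import torch.distributed as dist
    from colossalai.moe._operation import all_to_all_uneven
    from colossalai.quantization.fp8 import all_reduce_fp8

    def moe_dispatch_combine(dispatch_states, input_split_list, output_split_list,
                             ep_group, moe_dp_group, activate_experts, run_local_experts,
                             use_fp8: bool):
        # Aggregate expert-activation statistics across the MoE data-parallel group,
        # optionally over the fp8 all-reduce wired in by this commit.
        if use_fp8:
            all_reduce_fp8(activate_experts, group=moe_dp_group)
        else:
            dist.all_reduce(activate_experts, group=moe_dp_group)

        # Send each rank the tokens routed to its local experts (fp8 payload when enabled).
        output_states, _ = all_to_all_uneven(
            dispatch_states, input_split_list, output_split_list, ep_group, fp8_communication=use_fp8
        )
        output_states = run_local_experts(output_states)  # placeholder for the expert MLPs
        # Return the expert outputs to the ranks that own the original tokens.
        combined, _ = all_to_all_uneven(
            output_states, output_split_list, input_split_list, ep_group, fp8_communication=use_fp8
        )
        return combined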
@@ -53,7 +53,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
     def __init__(self, *args, **kwargs):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
 
-    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
+    def setup_process_groups(
+        self,
+        tp_group: ProcessGroup,
+        moe_dp_group: ProcessGroup,
+        ep_group: ProcessGroup,
+        fp8_communication: bool = False,
+    ):
         assert tp_group is not None
         assert moe_dp_group is not None
         assert ep_group is not None
@@ -62,6 +68,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         self.ep_size = dist.get_world_size(ep_group)
         self.ep_rank = dist.get_rank(ep_group)
         self.ep_group = ep_group
+        self.fp8_communication = fp8_communication
 
         if self.num_experts % self.ep_size != 0:
             raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
@@ -80,9 +87,15 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         self.tp_group = tp_group
         if self.tp_group.size() > 1:
             for expert in held_experts:
-                expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group)
-                expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group)
-                expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group)
+                expert.w1 = Linear1D_Col.from_native_module(
+                    expert.w1, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.w3 = Linear1D_Col.from_native_module(
+                    expert.w3, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.w2 = Linear1D_Row.from_native_module(
+                    expert.w2, self.tp_group, fp8_communication=self.fp8_communication
+                )
 
         for p in self.experts.parameters():
             set_moe_tensor_ep_group(p, ep_group)
@@ -99,7 +112,8 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         # TODO: better init
         LazyInitContext.materialize(module)
         module.__class__ = EPMixtralSparseMoeBlock
-        module.setup_process_groups(tp_group, moe_dp_group, ep_group)
+        fp8_communication = kwargs.get("fp8_communication", False)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication)
         return module
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -120,6 +134,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
+
         output_split_sizes = torch.zeros_like(input_split_sizes)
 
         dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
 
         with torch.no_grad():
@@ -132,7 +147,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
 
-        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
+        output_states, _ = all_to_all_uneven(
+            dispatch_states,
+            input_split_list,
+            output_split_list,
+            self.ep_group,
+            fp8_communication=self.fp8_communication,
+        )
         # compute expert output
         output_states = EPGradScalerIn.apply(output_states, self.ep_size)
         if output_states.size(0) > 0:
@@ -162,7 +183,9 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
             output_states = torch.cat(output_states_list)
 
         output_states = EPGradScalerOut.apply(output_states, self.ep_size)
-        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
+        dispatch_states, _ = all_to_all_uneven(
+            output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
+        )
 
         recover_experts_idx = torch.empty_like(selected_experts_idx)
         recover_experts_idx[selected_experts_idx] = torch.arange(
@@ -566,9 +589,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
 
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group)
-            key_states = all_to_all_comm(key_states, sp_group)
-            value_states = all_to_all_comm(value_states, sp_group)
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
 
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -780,9 +803,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
         )
 
         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+            )
         hidden_states = inputs_embeds
 
         # decoder layers
@@ -831,9 +858,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
         hidden_states = self.norm(hidden_states)
 
         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+            )
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
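When tensor parallelism is enabled inside the expert MLPs, the same flag is forwarded to the column/row linear wrappers, as the `setup_process_groups` hunks above show. A minimal sketch of that wrapping step, assuming an initialized tensor-parallel group and plain `nn.Linear` expert projections:

    from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row

    def shard_expert(expert, tp_group, use_fp8: bool):
        # Mirrors the w1/w3 (column-parallel) and w2 (row-parallel) replacement in this commit.
        expert.w1 = Linear1D_Col.from_native_module(expert.w1, tp_group, fp8_communication=use_fp8)
        expert.w3 = Linear1D_Col.from_native_module(expert.w3, tp_group, fp8_communication=use_fp8)
        expert.w2 = Linear1D_Row.from_native_module(expert.w2, tp_group, fp8_communication=use_fp8)
        return expert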
@@ -118,18 +118,22 @@ class DeepseekPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
                         target_module=Linear1D_Col,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.k_proj",
                         target_module=Linear1D_Col,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.v_proj",
                         target_module=Linear1D_Col,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.o_proj",
                         target_module=Linear1D_Row,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                 ],
             )
@@ -138,7 +142,10 @@ class DeepseekPolicy(Policy):
                 description=SubModuleReplacementDescription(
                     suffix="embed_tokens",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs={
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 policy=policy,
                 target_key="DeepseekModel",
@@ -155,6 +162,7 @@ class DeepseekPolicy(Policy):
                            "ep_group": self.shard_config.ep_group,
                            "tp_group": self.shard_config.tensor_parallel_process_group,
                            "moe_dp_group": self.shard_config.moe_dp_group,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     )
                 ],
@@ -305,7 +313,7 @@ class DeepseekForCausalLMPolicy(DeepseekPolicy):
                 SubModuleReplacementDescription(
                     suffix="lm_head",
                     target_module=Linear1D_Col,
-                    kwargs=dict(gather_output=True),
+                    kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
                 )
             ]
         )
@@ -114,21 +114,27 @@ class MixtralPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
                         target_module=Linear1D_Col,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.k_proj",
                         target_module=Linear1D_Col,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.v_proj",
                         target_module=Linear1D_Col,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.o_proj",
                         target_module=Linear1D_Row,
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                     SubModuleReplacementDescription(  # or replicate?
-                        suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True}
+                        suffix="block_sparse_moe.gate",
+                        target_module=Linear1D_Col,
+                        kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication},
                     ),
                 ],
             )
@@ -138,7 +144,10 @@ class MixtralPolicy(Policy):
                 description=SubModuleReplacementDescription(
                     suffix="embed_tokens",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs={
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 policy=policy,
                 target_key=MixtralModel,
@@ -155,6 +164,7 @@ class MixtralPolicy(Policy):
                            "ep_group": self.shard_config.ep_group,
                            "tp_group": self.shard_config.tensor_parallel_process_group,
                            "moe_dp_group": self.shard_config.moe_dp_group,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     )
                 ],
@@ -282,7 +292,7 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
                 SubModuleReplacementDescription(
                     suffix="lm_head",
                     target_module=Linear1D_Col,
-                    kwargs=dict(gather_output=True),
+                    kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
                 )
             ]
        )
@@ -336,7 +346,9 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
            MixtralForSequenceClassification: ModulePolicyDescription(
                sub_module_replacement=[
                    SubModuleReplacementDescription(
-                        suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                        suffix="score",
+                        target_module=Linear1D_Col,
+                        kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
                    )
                ]
            )