mirror of https://github.com/InternLM/InternLM
add share embedding weight support for moe
parent eeef07934a
commit bf6dbf07fa
@@ -129,6 +129,7 @@ model = dict(
     checkpoint=False,  # The proportion of layers for activation checkpointing; the optional values are True/False/[0-1]
     num_attention_heads=NUM_ATTENTION_HEAD,
     embed_split_hidden=True,
+    tie_embeddings_and_output_weights=False,
     vocab_size=VOCAB_SIZE,
     embed_grad_scale=1,
     parallel_output=True,
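
The new flag defaults to False, so existing configs are unaffected. A hypothetical config fragment that turns weight tying on might look like the sketch below (placeholder values only; a later hunk asserts that embed_split_hidden must be False when the weights are tied):

# Illustrative config fragment; NUM_ATTENTION_HEAD and VOCAB_SIZE are
# placeholders here, not the values used in the real InternLM configs.
NUM_ATTENTION_HEAD = 32
VOCAB_SIZE = 50304

model = dict(
    checkpoint=False,
    num_attention_heads=NUM_ATTENTION_HEAD,
    embed_split_hidden=False,                 # weight tying requires this to be False (asserted below)
    tie_embeddings_and_output_weights=True,   # the new flag introduced by this commit
    vocab_size=VOCAB_SIZE,
    embed_grad_scale=1,
    parallel_output=True,
)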
@@ -48,6 +48,9 @@ class ParallelMode(Enum):
     # expert data parallel
     EXPERT_DATA = "expert_data"
 
+    # embedding share
+    EMBEDDING = "embedding"
+
     # dummy mode, only used during mode construction
     DUMMY = "dummy"
@@ -236,8 +239,8 @@ class Initializer_Pipeline(ProcessGroupInitializer):
         process_group = None
         cpu_group = None
         group_world_size = None
-        mode = ParallelMode.PIPELINE
 
+        groups = []
         for i in range(self.data_parallel_size):
             for j in range(self.pipeline_stage_size):
                 ranks = list(
@@ -265,7 +268,37 @@ class Initializer_Pipeline(ProcessGroupInitializer):
                     cpu_group = group_cpu
                     ranks_in_group = ranks
 
-        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
+                groups.append(
+                    (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, ParallelMode.PIPELINE)
+                )
+
+                # create embedding communication group
+                if len(ranks) > 1:
+                    embedding_ranks = [ranks[0], ranks[-1]]
+                else:
+                    embedding_ranks = ranks
+                embed_group = dist.new_group(embedding_ranks, timeout=LLM_NCCL_TIMEOUT)
+                if use_cpu:
+                    group_cpu = (
+                        dist.new_group(embedding_ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                        if dist.get_backend() != "gloo"
+                        else embed_group
+                    )
+                else:
+                    group_cpu = None
+
+                if self.rank in ranks:
+                    local_rank = ranks.index(self.rank)
+                    group_world_size = len(embedding_ranks)
+                    process_group = embed_group
+                    cpu_group = group_cpu
+                    ranks_in_group = embedding_ranks
+
+                groups.append(
+                    (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, ParallelMode.EMBEDDING)
+                )
+
+        return groups
 
 
 class Initializer_Tensor(ProcessGroupInitializer):
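
To illustrate the grouping logic above: each pipeline group contributes one embedding group made of its first and last rank, the two stages that hold the tied embedding and output weights. A minimal standalone sketch of that rank selection (pure Python, hypothetical helper name, no process groups created):

from typing import List

def embedding_groups_from_pipeline_groups(pipeline_groups: List[List[int]]) -> List[List[int]]:
    """For every pipeline group, pair its first and last rank; a single-stage
    group maps to itself."""
    groups = []
    for ranks in pipeline_groups:
        groups.append([ranks[0], ranks[-1]] if len(ranks) > 1 else list(ranks))
    return groups

# Example: 8 ranks, 2 data-parallel replicas, 4 pipeline stages each.
print(embedding_groups_from_pipeline_groups([[0, 1, 2, 3], [4, 5, 6, 7]]))
# -> [[0, 3], [4, 7]]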
@@ -9,6 +9,7 @@ import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 from internlm.core.context import global_context as gpc
+from internlm.core.context.process_group_initializer import ParallelMode
 
 
 class BaseGradientHandler(ABC):
@@ -74,3 +75,26 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler):
                     dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
                     for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                         buf.copy_(synced)
+
+
+class EmbeddingSharedModuleGradientHandler(BaseGradientHandler):
+    """A helper class to handle all-reduce operations in embedding share groups.
+    An all-reduce collective communication is performed in
+    :func:`handle_gradient` between the first pipeline stage and the last pipeline stage.
+    For better performance, it bucketizes the gradients of all parameters of
+    the same type to improve the efficiency of communication.
+
+    Args:
+        model (Module): Model where the gradients accumulate.
+        optimizer (Optimizer): Optimizer for updating the parameters.
+    """
+
+    def handle_gradient(self):
+        """A method running an all-reduce operation in sub pipeline parallel groups."""
+        if gpc.is_pipeline_first_stage() or gpc.is_pipeline_last_stage():
+            weight = self._model.model.shared_embedding_weight()
+            grad = weight.grad
+            # enabling ZeRO can leave grad as None
+            if grad is None:
+                grad = torch.zeros_like(weight)
+            torch.distributed.all_reduce(grad, group=gpc.get_group(parallel_mode=ParallelMode.EMBEDDING))
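
The handler above reduces to a small synchronization primitive: sum the tied weight's gradient across the two ends of the embedding group so both copies apply identical updates. A minimal sketch, assuming the two-rank group already exists (the function name is illustrative):

import torch
import torch.distributed as dist

def sync_tied_embedding_grad(weight: torch.Tensor, embedding_group) -> None:
    """Sum the gradient of the tied embedding/output weight across the first
    and last pipeline stages so both copies apply the same update."""
    grad = weight.grad
    if grad is None:
        # With ZeRO the .grad attribute may be cleared; participate in the
        # collective with zeros so the peer stage still completes its all-reduce.
        grad = torch.zeros_like(weight)
    dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=embedding_group)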
@@ -14,7 +14,10 @@ from torch.utils.data import DataLoader
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.engine import Engine
-from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
+from internlm.core.gradient_handler import (
+    EmbeddingSharedModuleGradientHandler,
+    PipelineSharedModuleGradientHandler,
+)
 from internlm.core.scheduler import (
     InterleavedPipelineScheduler,
     NonPipelineScheduler,
@@ -68,8 +71,12 @@ def initialize_trainer(
     assert isinstance(optimizer, BaseOptimizer), "optimizer must be instance of BaseOptimizer"
 
     # gradient handler, only support PipelineSharedModuleGradientHandler now
+    # TODO: can refactor code here
     if gpc.is_using_pp():
-        gpc.config.gradient_handler = [dict(type="PipelineSharedModuleGradientHandler")]
+        gpc.config.gradient_handler = [
+            dict(type="PipelineSharedModuleGradientHandler"),
+            dict(type="EmbeddingSharedModuleGradientHandler"),
+        ]
     gradient_handler_cfg = gpc.config.get("gradient_handler", [])
     gradient_handlers = []
     assert isinstance(gradient_handler_cfg, list), f"gradient_handler must be list but got {type(gradient_handler_cfg)}"
@@ -77,6 +84,14 @@ def initialize_trainer(
         if isinstance(config, dict) and config.get("type") == "PipelineSharedModuleGradientHandler":
             handler = PipelineSharedModuleGradientHandler(model=model, optimizer=optimizer)
             gradient_handlers.append(handler)
+        if (
+            isinstance(config, dict)
+            and config.get("type") == "EmbeddingSharedModuleGradientHandler"
+            and gpc.config.model.get("tie_embeddings_and_output_weights", False)
+            and gpc.pipeline_parallel_size > 1
+        ):
+            handler = EmbeddingSharedModuleGradientHandler(model=model, optimizer=optimizer)
+            gradient_handlers.append(handler)
 
     # initialize scheduler for trainer
     scheduler = None
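
For context, gradient handlers built this way are typically invoked after backward and before the optimizer step. The sketch below shows that calling pattern with a stand-in handler; the trainer wiring here is illustrative, not InternLM's actual call site:

import torch

class DummyHandler:
    def handle_gradient(self):
        print("synchronizing shared gradients")

def run_step(loss: torch.Tensor, optimizer: torch.optim.Optimizer, gradient_handlers) -> None:
    loss.backward()
    for handler in gradient_handlers:
        handler.handle_gradient()  # e.g. all-reduce the tied embedding gradient
    optimizer.step()
    optimizer.zero_grad()

param = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.SGD([param], lr=0.1)
run_step((param * 2).sum(), opt, [DummyHandler()])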
@@ -293,6 +293,8 @@ def args_sanity_check():
         model._add_item("moe_use_residual", False)
     if "moe_gate_k" not in model:
         model._add_item("moe_gate_k", 2)
+    if "tie_embeddings_and_output_weights" not in model:
+        model._add_item("tie_embeddings_and_output_weights", False)
     assert not (
         gpc.config.model.num_experts > 1 and gpc.config.parallel.zero1.fsdp
     ), "FSDP does not support num_experts > 1"
@@ -40,22 +40,30 @@ class ScaleColumnParallelLinear(nn.Linear):
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         weight_scale: int = 1,
+        skip_weight_alloction: bool = False,
     ) -> None:
         world_size = torch.distributed.get_world_size(process_group)
         if out_features % world_size != 0:
             raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
         super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
+        if skip_weight_alloction:
+            del self.weight
+            self.register_parameter("weight", None)
         self.process_group = process_group
         self.weight_scale = weight_scale
 
-    def forward(self, input):  # pylint: disable=W0622
+    def forward(self, input, shared_weight: Optional[torch.Tensor] = None):  # pylint: disable=W0622
         # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
         # we do an all_gather of x before doing the matmul.
         # If not, then the input is already gathered.
+        if shared_weight is None:
+            if self.weight is None:
+                raise RuntimeError("weight was not given in forward pass " "and skip_weight_allocation is True.")
+            shared_weight = self.weight
         if self.weight_scale != 1:
-            weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
+            weight = shared_weight * self.weight_scale + (1 - self.weight_scale) * shared_weight.detach()
         else:
-            weight = self.weight
+            weight = shared_weight
         return fused_dense_func_torch(
             input,
             weight,
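
With skip_weight_alloction set, the head owns no weight of its own and projects with whatever shared weight is passed in at forward time. The single-GPU sketch below mirrors that pattern (class and variable names are hypothetical), including the weight_scale trick, which leaves the forward value unchanged while scaling the gradient that reaches the shared weight:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TiedHead(nn.Module):
    """Sketch of a head that allocates no weight and uses the embedding weight."""

    def __init__(self, weight_scale: float = 1.0):
        super().__init__()
        self.weight_scale = weight_scale

    def forward(self, x: torch.Tensor, shared_weight: torch.Tensor) -> torch.Tensor:
        if self.weight_scale != 1:
            # Forward value equals shared_weight, but the gradient flowing back
            # into shared_weight is scaled by weight_scale (the embed_grad_scale idea).
            w = shared_weight * self.weight_scale + (1 - self.weight_scale) * shared_weight.detach()
        else:
            w = shared_weight
        return F.linear(x, w)

embedding = nn.Embedding(1000, 64)
head = TiedHead()
tokens = torch.randint(0, 1000, (2, 8))
logits = head(embedding(tokens), embedding.weight)  # shape (2, 8, 1000), no separate head weight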
@@ -91,7 +99,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         weight_scale: int = 1,
+        skip_weight_alloction: bool = False,
     ) -> None:
+        # TODO: RewardModelLinear is not used for now
+        assert not skip_weight_alloction, "shared weight is not supported here for now"
+
         super().__init__(in_features, out_features, process_group, bias, device, dtype, weight_scale)
         torch.distributed.broadcast(self.weight, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
         if bias:
@@ -102,7 +114,10 @@ class RewardModelLinear(ScaleColumnParallelLinear):
         # we do an all_gather of x before doing the matmul.
         # If not, then the input is already gathered.
         if self.weight_scale != 1:
-            weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
+            weight = (
+                self.weight * self.weight_scale
+                + (1 - self.weight_scale) * self.weight.detach()  # pylint: disable=not-callable
+            )
         else:
             weight = self.weight
         return fused_dense_func_torch(
@@ -331,6 +331,7 @@ class PackedFlashInternLm1D(nn.Module):
         moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
         moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
             (https://arxiv.org/abs/2201.05596) layer.
+        tie_embeddings_and_output_weights (bool, optional): whether the embedding and the output layer share the same weight.
     """
 
     def __init__(
@@ -370,9 +371,15 @@ class PackedFlashInternLm1D(nn.Module):
         moe_drop_tokens: bool = True,
         moe_use_rts: bool = True,
         moe_use_residual: bool = False,
+        tie_embeddings_and_output_weights: bool = True,
     ):
         super().__init__()
 
+        assert not (
+            embed_split_hidden and tie_embeddings_and_output_weights
+        ), "shared embedding weights are not supported when embed_split_hidden is True."
+        self.tie_embeddings_and_output_weights = tie_embeddings_and_output_weights
+
         checkpoint_layer_num = int(num_layers * checkpoint)
 
         if is_reward:
@@ -446,6 +453,7 @@ class PackedFlashInternLm1D(nn.Module):
                 device=device,
                 dtype=dtype,
                 weight_scale=embed_grad_scale,
+                skip_weight_alloction=self.tie_embeddings_and_output_weights,
             )
             for _, param in self.head.named_parameters():
                 normal_(std=0.0052)(param)
@@ -453,6 +461,9 @@ class PackedFlashInternLm1D(nn.Module):
                     setattr(param, IS_TENSOR_PARALLEL, True)
         self.parallel_output = parallel_output
 
+        if self.tie_embeddings_and_output_weights:
+            self.initialize_word_embeddings(hidden_size, vocab_size, dtype, device)
+
     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
         # attention_mask: compute attention on the places where the value is 1
         # old condition may fail when use shared embedding
@@ -491,12 +502,79 @@ class PackedFlashInternLm1D(nn.Module):
         if hasattr(self, "norm"):
             hidden_states = self.norm(hidden_states.float())
         if hasattr(self, "head"):
-            hidden_states = self.head(hidden_states)
+            if self.tie_embeddings_and_output_weights:
+                hidden_states = self.head(hidden_states, self.shared_embedding_weight())
+            else:
+                hidden_states = self.head(hidden_states)
 
         if not self.parallel_output:
             hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
         return hidden_states, moe_losses
 
+    def shared_embedding_weight(self):
+        if not self.tie_embeddings_and_output_weights:
+            raise Exception(
+                "shared_embedding_weight() called for last stage, but share_embeddings_and_output_weights is false"
+            )
+        assert isinstance(self.embedding, ParallelGPT2Embeddings)
+
+        return self.embedding.word_embeddings.weight
+
+    # TODO: refactor code
+    def initialize_word_embeddings(
+        self,
+        hidden_size: int = 768,
+        vocab_size: int = 50304,
+        dtype: torch.dtype = torch.float,
+        device: Optional[torch.device] = None,
+    ):
+        if not self.tie_embeddings_and_output_weights:
+            raise Exception("initialize_word_embeddings() was called but " "tie_embeddings_and_output_weights is false")
+
+        # This function just initializes the word embeddings in the final stage
+        # when we are using pipeline parallelism. Nothing to do if we aren't
+        # using pipeline parallelism.
+        if gpc.get_world_size(ParallelMode.PIPELINE) == 1:
+            return
+
+        # Parameters are shared between the word embeddings layers, and the
+        # heads at the end of the model. In a pipelined setup with more than
+        # one stage, the initial embedding layer and the head are on different
+        # workers, so we do the following:
+        # 1. Create a second copy of word_embeddings on the last stage, with
+        #    initial parameters of 0.0.
+        # 2. Do an all-reduce between the first and last stage to ensure that
+        #    the two copies of word_embeddings start off with the same
+        #    parameter values.
+        # 3. In the training loop, do an all-reduce between the grads of
+        #    the two word_embeddings layers to ensure that every applied weight
+        #    update is the same on both stages.
+        if gpc.is_pipeline_last_stage():
+            assert not gpc.is_pipeline_first_stage()
+            # set word_embeddings weights to 0 here, then copy first
+            # stage's weights using all_reduce below.
+            self.embedding = ParallelGPT2Embeddings(
+                embed_dim=hidden_size,
+                vocab_size=vocab_size,
+                max_position_embeddings=-1,
+                process_group=gpc.get_group(ParallelMode.TENSOR),
+                padding_idx=None,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                device=device,
+                dtype=dtype,
+            )
+            for _, param in self.embedding.named_parameters():
+                if gpc.get_world_size(ParallelMode.TENSOR) > 1:
+                    setattr(param, IS_TENSOR_PARALLEL, True)
+            self.shared_embedding_weight().data.fill_(0)
+
+        # Ensure that first and last stages have the same initial parameter
+        # values.
+        if gpc.is_pipeline_first_stage() or gpc.is_pipeline_last_stage():
+            torch.distributed.all_reduce(
+                self.shared_embedding_weight().data, group=gpc.get_group(ParallelMode.EMBEDDING)
+            )
+
 
 def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
     """
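
The initialization above follows the Megatron-style recipe: the last pipeline stage builds a zero-initialized copy of the word embeddings, then an all-reduce over the embedding group copies the first stage's values onto it. A minimal sketch of that startup synchronization, assuming the two-rank group already exists (illustrative function name):

import torch
import torch.distributed as dist

def sync_tied_embedding_init(weight: torch.Tensor, embedding_group, is_last_stage: bool) -> None:
    """Zero the last stage's copy, then all-reduce so both stages start from
    the first stage's initial embedding values."""
    if is_last_stage:
        weight.data.fill_(0)
    dist.all_reduce(weight.data, op=dist.ReduceOp.SUM, group=embedding_group)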
@@ -572,6 +650,7 @@ def build_model_with_moe_cfg(
     moe_drop_tokens: bool = True,
     moe_use_rts: bool = True,
     moe_use_residual: bool = False,
+    tie_embeddings_and_output_weights=False,
 ):
     """
     Build model with config.
@@ -613,6 +692,7 @@ def build_model_with_moe_cfg(
         moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
         moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
             (https://arxiv.org/abs/2201.05596) layer.
+        tie_embeddings_and_output_weights (bool, optional): whether the embedding and the output layer share the same weight.
     """
 
     cfg = dict(
@@ -646,6 +726,7 @@ def build_model_with_moe_cfg(
         moe_drop_tokens=moe_drop_tokens,
         moe_use_rts=moe_use_rts,
         moe_use_residual=moe_use_residual,
+        tie_embeddings_and_output_weights=tie_embeddings_and_output_weights,
     )
 
     return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)