mirror of https://github.com/InternLM/InternLM
Add shared embedding weight support for MoE
parent eeef07934a
commit bf6dbf07fa
@@ -129,6 +129,7 @@ model = dict(
    checkpoint=False,  # The proportion of layers for activation checkpointing; the optional values are True/False/[0-1]
    num_attention_heads=NUM_ATTENTION_HEAD,
    embed_split_hidden=True,
    tie_embeddings_and_output_weights=False,
    vocab_size=VOCAB_SIZE,
    embed_grad_scale=1,
    parallel_output=True,
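To actually enable the new path, a user flips this flag in the model config. A minimal sketch, assuming the surrounding fields keep the values shown in the hunk above; note that the model-level hunk further down asserts that weight tying cannot be combined with embed_split_hidden=True:

    model = dict(
        checkpoint=False,
        num_attention_heads=NUM_ATTENTION_HEAD,
        embed_split_hidden=False,  # required: tying is rejected when embed_split_hidden=True
        tie_embeddings_and_output_weights=True,  # share the embedding matrix with the output head
        vocab_size=VOCAB_SIZE,
        embed_grad_scale=1,
        parallel_output=True,
    )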
@@ -48,6 +48,9 @@ class ParallelMode(Enum):
    # expert data parallel
    EXPERT_DATA = "expert_data"

    # embedding share
    EMBEDDING = "embedding"

    # dummy mode, only used during mode construction
    DUMMY = "dummy"
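Once the pipeline initializer below registers this mode, other components can look the group up through the global context. A minimal sketch, assuming gpc has already been initialized by the launcher:

    from internlm.core.context import global_context as gpc
    from internlm.core.context.process_group_initializer import ParallelMode

    # process group containing the first and last stage of the current pipeline
    embedding_group = gpc.get_group(ParallelMode.EMBEDDING)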
@@ -236,8 +239,8 @@ class Initializer_Pipeline(ProcessGroupInitializer):
        process_group = None
        cpu_group = None
        group_world_size = None
        mode = ParallelMode.PIPELINE

        groups = []
        for i in range(self.data_parallel_size):
            for j in range(self.pipeline_stage_size):
                ranks = list(
@@ -265,7 +268,37 @@ class Initializer_Pipeline(ProcessGroupInitializer):
                    cpu_group = group_cpu
                    ranks_in_group = ranks

        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
                groups.append(
                    (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, ParallelMode.PIPELINE)
                )

                # create embedding communication group
                if len(ranks) > 1:
                    embedding_ranks = [ranks[0], ranks[-1]]
                else:
                    embedding_ranks = ranks
                embed_group = dist.new_group(embedding_ranks, timeout=LLM_NCCL_TIMEOUT)
                if use_cpu:
                    group_cpu = (
                        dist.new_group(embedding_ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
                        if dist.get_backend() != "gloo"
                        else embed_group
                    )
                else:
                    group_cpu = None

                if self.rank in ranks:
                    local_rank = ranks.index(self.rank)
                    group_world_size = len(embedding_ranks)
                    process_group = embed_group
                    cpu_group = group_cpu
                    ranks_in_group = embedding_ranks

                groups.append(
                    (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, ParallelMode.EMBEDDING)
                )

        return groups


class Initializer_Tensor(ProcessGroupInitializer):
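For intuition, the embedding group simply pairs the first and last rank of each pipeline group. A small standalone sketch of that selection logic (the helper name and example rank lists are purely illustrative):

    def embedding_ranks_of(pipeline_ranks):
        # the first and last pipeline stages share the embedding weight,
        # so only those two ranks need to communicate
        if len(pipeline_ranks) > 1:
            return [pipeline_ranks[0], pipeline_ranks[-1]]
        return pipeline_ranks

    assert embedding_ranks_of([0, 2, 4, 6]) == [0, 6]
    assert embedding_ranks_of([1]) == [1]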
@@ -9,6 +9,7 @@ import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from internlm.core.context import global_context as gpc
from internlm.core.context.process_group_initializer import ParallelMode


class BaseGradientHandler(ABC):
@@ -74,3 +75,26 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler):
                    dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)


class EmbeddingSharedModuleGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in embedding share groups.
    An all-reduce collective communication is performed in
    :func:`handle_gradient` between the first pipeline stage and the last pipeline stage.
    For better performance, it bucketizes the gradients of all parameters that are
    the same type to improve the efficiency of communication.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def handle_gradient(self):
        """A method running an all-reduce operation in the embedding share group."""
        if gpc.is_pipeline_first_stage() or gpc.is_pipeline_last_stage():
            weight = self._model.model.shared_embedding_weight()
            grad = weight.grad
            # enabling ZeRO can leave grad as None
            if grad is None:
                grad = torch.zeros_like(weight)
            torch.distributed.all_reduce(grad, group=gpc.get_group(parallel_mode=ParallelMode.EMBEDDING))
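Conceptually the handler reduces to this pattern: the first and last pipeline stages each hold a copy of the embedding matrix, and their gradients are summed over the EMBEDDING group so both copies see the same update. A minimal standalone sketch with plain torch.distributed (the helper name and the explicit group argument are assumptions for illustration):

    import torch
    import torch.distributed as dist

    def sync_tied_embedding_grad(weight: torch.nn.Parameter, embedding_group: dist.ProcessGroup) -> None:
        # Both stages must join the collective; if ZeRO has already cleared .grad,
        # contribute a zero tensor so the all-reduce stays matched across ranks.
        grad = weight.grad if weight.grad is not None else torch.zeros_like(weight)
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=embedding_group)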
@@ -14,7 +14,10 @@ from torch.utils.data import DataLoader
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.engine import Engine
from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
from internlm.core.gradient_handler import (
    EmbeddingSharedModuleGradientHandler,
    PipelineSharedModuleGradientHandler,
)
from internlm.core.scheduler import (
    InterleavedPipelineScheduler,
    NonPipelineScheduler,
@@ -68,8 +71,12 @@ def initialize_trainer(
    assert isinstance(optimizer, BaseOptimizer), "optimizer must be instance of BaseOptimizer"

    # gradient handler, only support PipelineSharedModuleGradientHandler now
    # TODO: can refactor code here
    if gpc.is_using_pp():
        gpc.config.gradient_handler = [dict(type="PipelineSharedModuleGradientHandler")]
        gpc.config.gradient_handler = [
            dict(type="PipelineSharedModuleGradientHandler"),
            dict(type="EmbeddingSharedModuleGradientHandler"),
        ]
    gradient_handler_cfg = gpc.config.get("gradient_handler", [])
    gradient_handlers = []
    assert isinstance(gradient_handler_cfg, list), f"gradient_handler must be list but got {type(gradient_handler_cfg)}"
@@ -77,6 +84,14 @@ def initialize_trainer(
        if isinstance(config, dict) and config.get("type") == "PipelineSharedModuleGradientHandler":
            handler = PipelineSharedModuleGradientHandler(model=model, optimizer=optimizer)
            gradient_handlers.append(handler)
        if (
            isinstance(config, dict)
            and config.get("type") == "EmbeddingSharedModuleGradientHandler"
            and gpc.config.model.get("tie_embeddings_and_output_weights", False)
            and gpc.pipeline_parallel_size > 1
        ):
            handler = EmbeddingSharedModuleGradientHandler(model=model, optimizer=optimizer)
            gradient_handlers.append(handler)

    # initialize scheduler for trainer
    scheduler = None
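With pipeline parallelism enabled, the gradient-handler config therefore carries both entries, and the embedding handler is only instantiated when weight tying is on and there is more than one pipeline stage. A sketch of the resulting configuration:

    gradient_handler = [
        dict(type="PipelineSharedModuleGradientHandler"),
        # only instantiated when tie_embeddings_and_output_weights=True and pipeline_parallel_size > 1
        dict(type="EmbeddingSharedModuleGradientHandler"),
    ]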
@@ -293,6 +293,8 @@ def args_sanity_check():
        model._add_item("moe_use_residual", False)
    if "moe_gate_k" not in model:
        model._add_item("moe_gate_k", 2)
    if "tie_embeddings_and_output_weights" not in model:
        model._add_item("tie_embeddings_and_output_weights", False)
    assert not (
        gpc.config.model.num_experts > 1 and gpc.config.parallel.zero1.fsdp
    ), "FSDP does not support num_experts > 1"
@@ -40,22 +40,30 @@ class ScaleColumnParallelLinear(nn.Linear):
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        weight_scale: int = 1,
        skip_weight_alloction: bool = False,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % world_size != 0:
            raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
        super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
        if skip_weight_alloction:
            del self.weight
            self.register_parameter("weight", None)
        self.process_group = process_group
        self.weight_scale = weight_scale

    def forward(self, input):  # pylint: disable=W0622
    def forward(self, input, shared_weight: Optional[torch.Tensor] = None):  # pylint: disable=W0622
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        if shared_weight is None:
            if self.weight is None:
                raise RuntimeError("weight was not given in forward pass " "and skip_weight_allocation is True.")
            shared_weight = self.weight
        if self.weight_scale != 1:
            weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
            weight = shared_weight * self.weight_scale + (1 - self.weight_scale) * shared_weight.detach()
        else:
            weight = self.weight
            weight = shared_weight
        return fused_dense_func_torch(
            input,
            weight,
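The new forward signature lets the caller supply the embedding matrix in place of the head's own (possibly unallocated) weight. A minimal usage sketch; HIDDEN_SIZE, VOCAB_SIZE, hidden_states and embedding are placeholders, not names taken from this diff:

    # the head is built without its own weight when embeddings are tied
    head = ScaleColumnParallelLinear(
        in_features=HIDDEN_SIZE,
        out_features=VOCAB_SIZE,
        process_group=gpc.get_group(ParallelMode.TENSOR),
        bias=False,
        skip_weight_alloction=True,
    )
    # at forward time, reuse the embedding matrix as the output projection
    logits = head(hidden_states, shared_weight=embedding.word_embeddings.weight)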
@@ -91,7 +99,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        weight_scale: int = 1,
        skip_weight_alloction: bool = False,
    ) -> None:
        # TODO: RewardModelLinear is not used for now
        assert not skip_weight_alloction, "shared weight not support here for now"

        super().__init__(in_features, out_features, process_group, bias, device, dtype, weight_scale)
        torch.distributed.broadcast(self.weight, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
        if bias:
@@ -102,7 +114,10 @@ class RewardModelLinear(ScaleColumnParallelLinear):
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        if self.weight_scale != 1:
            weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
            weight = (
                self.weight * self.weight_scale
                + (1 - self.weight_scale) * self.weight.detach()  # pylint: disable=not-callable
            )
        else:
            weight = self.weight
        return fused_dense_func_torch(
@@ -331,6 +331,7 @@ class PackedFlashInternLm1D(nn.Module):
        moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
            (https://arxiv.org/abs/2201.05596) layer.
        tie_embeddings_and_output_weights: embedding and output layer share the same weight.
    """

    def __init__(
@@ -370,9 +371,15 @@ class PackedFlashInternLm1D(nn.Module):
        moe_drop_tokens: bool = True,
        moe_use_rts: bool = True,
        moe_use_residual: bool = False,
        tie_embeddings_and_output_weights: bool = True,
    ):
        super().__init__()

        assert not (
            embed_split_hidden and tie_embeddings_and_output_weights
        ), "shared embedding weights is not supported when embed_split_hidden is True."
        self.tie_embeddings_and_output_weights = tie_embeddings_and_output_weights

        checkpoint_layer_num = int(num_layers * checkpoint)

        if is_reward:
@@ -446,6 +453,7 @@
            device=device,
            dtype=dtype,
            weight_scale=embed_grad_scale,
            skip_weight_alloction=self.tie_embeddings_and_output_weights,
        )
        for _, param in self.head.named_parameters():
            normal_(std=0.0052)(param)
@@ -453,6 +461,9 @@
                setattr(param, IS_TENSOR_PARALLEL, True)
        self.parallel_output = parallel_output

        if self.tie_embeddings_and_output_weights:
            self.initialize_word_embeddings(hidden_size, vocab_size, dtype, device)

    def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
        # attention_mask: compute attention on the places where the value is 1
        # old condition may fail when using shared embedding
@@ -491,12 +502,79 @@
        if hasattr(self, "norm"):
            hidden_states = self.norm(hidden_states.float())
        if hasattr(self, "head"):
            hidden_states = self.head(hidden_states)
            if self.tie_embeddings_and_output_weights:
                hidden_states = self.head(hidden_states, self.shared_embedding_weight())
            else:
                hidden_states = self.head(hidden_states)

        if not self.parallel_output:
            hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
        return hidden_states, moe_losses

    def shared_embedding_weight(self):
        if not self.tie_embeddings_and_output_weights:
            raise Exception(
                "shared_embedding_weight() called for last stage, but share_embeddings_and_output_weights is false"
            )
        assert isinstance(self.embedding, ParallelGPT2Embeddings)

        return self.embedding.word_embeddings.weight

    # TODO: refactor code
    def initialize_word_embeddings(
        self,
        hidden_size: int = 768,
        vocab_size: int = 50304,
        dtype: torch.dtype = torch.float,
        device: Optional[torch.device] = None,
    ):
        if not self.tie_embeddings_and_output_weights:
            raise Exception("initialize_word_embeddings() was called but " "tie_embeddings_and_output_weights is false")

        # This function just initializes the word embeddings in the final stage
        # when we are using pipeline parallelism. Nothing to do if we aren't
        # using pipeline parallelism.
        if gpc.get_world_size(ParallelMode.PIPELINE) == 1:
            return

        # Parameters are shared between the word embeddings layers, and the
        # heads at the end of the model. In a pipelined setup with more than
        # one stage, the initial embedding layer and the head are on different
        # workers, so we do the following:
        # 1. Create a second copy of word_embeddings on the last stage, with
        #    initial parameters of 0.0.
        # 2. Do an all-reduce between the first and last stage to ensure that
        #    the two copies of word_embeddings start off with the same
        #    parameter values.
        # 3. In the training loop, perform an all-reduce between the grads of
        #    the two word_embeddings layers to ensure that every applied weight
        #    update is the same on both stages.
        if gpc.is_pipeline_last_stage():
            assert not gpc.is_pipeline_first_stage()
            # set word_embeddings weights to 0 here, then copy first
            # stage's weights using all_reduce below.
            self.embedding = ParallelGPT2Embeddings(
                embed_dim=hidden_size,
                vocab_size=vocab_size,
                max_position_embeddings=-1,
                process_group=gpc.get_group(ParallelMode.TENSOR),
                padding_idx=None,
                sequence_parallel=gpc.config.parallel.sequence_parallel,
                device=device,
                dtype=dtype,
            )
            for _, param in self.embedding.named_parameters():
                if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                    setattr(param, IS_TENSOR_PARALLEL, True)
            self.shared_embedding_weight().data.fill_(0)

        # Ensure that first and last stages have the same initial parameter
        # values.
        if gpc.is_pipeline_first_stage() or gpc.is_pipeline_last_stage():
            torch.distributed.all_reduce(
                self.shared_embedding_weight().data, group=gpc.get_group(ParallelMode.EMBEDDING)
            )


def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
    """
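The three-step scheme described in the comment above can be exercised in isolation: zero-initialize the last stage's copy, all-reduce the weights once at start-up, then all-reduce the gradients every step. A minimal two-process sketch using plain torch.distributed on CPU (process-group setup, tensor sizes and the toy update are assumptions for illustration):

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def run(rank: int, world_size: int) -> None:
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29511"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # rank 0 plays the first pipeline stage, rank 1 the last stage
        # step 1: the last stage's copy starts at zero
        weight = torch.randn(4, 2) if rank == 0 else torch.zeros(4, 2)

        # step 2: all-reduce the weights so both copies start identical
        dist.all_reduce(weight)

        # step 3 (every training step): all-reduce the gradients so both
        # copies receive exactly the same update
        grad = torch.full_like(weight, float(rank + 1))
        dist.all_reduce(grad)
        weight -= 0.1 * grad

        dist.destroy_process_group()

    if __name__ == "__main__":
        mp.spawn(run, args=(2,), nprocs=2)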
@@ -572,6 +650,7 @@ def build_model_with_moe_cfg(
    moe_drop_tokens: bool = True,
    moe_use_rts: bool = True,
    moe_use_residual: bool = False,
    tie_embeddings_and_output_weights=False,
):
    """
    Build model with config.
@@ -613,6 +692,7 @@
        moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
            (https://arxiv.org/abs/2201.05596) layer.
        tie_embeddings_and_output_weights: embedding and output layer share the same weight.
    """

    cfg = dict(
@@ -646,6 +726,7 @@
        moe_drop_tokens=moe_drop_tokens,
        moe_use_rts=moe_use_rts,
        moe_use_residual=moe_use_residual,
        tie_embeddings_and_output_weights=tie_embeddings_and_output_weights,
    )

    return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
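Putting the pieces together, a caller threads the new flag through the builder. A minimal sketch; every argument except tie_embeddings_and_output_weights is a placeholder assumed from the surrounding config, not taken from this diff:

    model = build_model_with_moe_cfg(
        num_layers=NUM_LAYER,                    # assumed builder kwarg
        hidden_size=HIDDEN_SIZE,                 # assumed builder kwarg
        vocab_size=VOCAB_SIZE,                   # assumed builder kwarg
        num_attention_heads=NUM_ATTENTION_HEAD,  # assumed builder kwarg
        embed_split_hidden=False,                # tying is rejected when embed_split_hidden=True
        tie_embeddings_and_output_weights=True,
    )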