add shared embedding weight support for MoE

pull/422/head
Wenwen Qu 2023-10-18 11:39:04 +08:00
parent eeef07934a
commit bf6dbf07fa
7 changed files with 180 additions and 9 deletions


@@ -129,6 +129,7 @@ model = dict(
checkpoint=False, # The proportion of layers for activation checkpointing; the optional values are True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
embed_split_hidden=True,
tie_embeddings_and_output_weights=False,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
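For reference, a minimal config sketch (not part of this commit) that turns the new option on; the values are illustrative, and the model code in this commit asserts that weight tying is not supported together with embed_split_hidden=True:

# illustrative excerpt of a training config with weight tying enabled
model = dict(
    embed_split_hidden=False,                # must be False when tying is enabled
    tie_embeddings_and_output_weights=True,  # share the embedding table with the output head
    vocab_size=103168,                       # illustrative value
    embed_grad_scale=1,
    parallel_output=True,
)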


@@ -48,6 +48,9 @@ class ParallelMode(Enum):
# expert data parallel
EXPERT_DATA = "expert_data"
# embedding share
EMBEDDING = "embedding"
# dummy mode, only used during mode construction
DUMMY = "dummy"
@@ -236,8 +239,8 @@ class Initializer_Pipeline(ProcessGroupInitializer):
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PIPELINE
groups = []
for i in range(self.data_parallel_size):
for j in range(self.pipeline_stage_size):
ranks = list(
@@ -265,7 +268,37 @@ class Initializer_Pipeline(ProcessGroupInitializer):
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
groups.append(
(local_rank, group_world_size, process_group, cpu_group, ranks_in_group, ParallelMode.PIPELINE)
)
# create embedding communication group
if len(ranks) > 1:
embedding_ranks = [ranks[0], ranks[-1]]
else:
embedding_ranks = ranks
embed_group = dist.new_group(embedding_ranks, timeout=LLM_NCCL_TIMEOUT)
if use_cpu:
group_cpu = (
dist.new_group(embedding_ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
if dist.get_backend() != "gloo"
else embed_group
)
else:
group_cpu = None
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(embedding_ranks)
process_group = embed_group
cpu_group = group_cpu
ranks_in_group = embedding_ranks
groups.append(
(local_rank, group_world_size, process_group, cpu_group, ranks_in_group, ParallelMode.EMBEDDING)
)
return groups
class Initializer_Tensor(ProcessGroupInitializer):
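To illustrate the group construction above: within each pipeline group only the first and the last stage hold a copy of the tied weight, so only those two ranks join the EMBEDDING group. A single-process sketch of the rank selection (the helper name is made up for illustration):

def embedding_ranks_for(pipeline_ranks):
    # only the first and the last pipeline stage keep a copy of the tied weight
    if len(pipeline_ranks) > 1:
        return [pipeline_ranks[0], pipeline_ranks[-1]]
    return pipeline_ranks

print(embedding_ranks_for([0, 4, 8, 12]))  # [0, 12]
print(embedding_ranks_for([3]))            # [3] -- a one-stage pipeline degenerates to a single rank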


@@ -9,6 +9,7 @@ import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from internlm.core.context import global_context as gpc
from internlm.core.context.process_group_initializer import ParallelMode
class BaseGradientHandler(ABC):
@@ -74,3 +75,26 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler):
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
class EmbeddingSharedModuleGradientHandler(BaseGradientHandler):
"""A helper class to handle the all-reduce of the shared embedding weight gradient.
An all-reduce collective communication is performed in :func:`handle_gradient`
between the first pipeline stage and the last pipeline stage, so that the tied
embedding and output weights receive identical gradient updates.
Args:
model (Module): Model where the gradients accumulate.
optimizer (Optimizer): Optimizer for updating the parameters.
"""
def handle_gradient(self):
"""A method running a all-reduce operation in sub pipeline parallel groups."""
if gpc.is_pipeline_first_stage() or gpc.is_pipeline_last_stage():
weight = self._model.model.shared_embedding_weight()
grad = weight.grad
# with ZeRO enabled, grad may be None on this rank; contribute zeros so the all-reduce still matches
if grad is None:
grad = torch.zeros_like(weight)
torch.distributed.all_reduce(grad, group=gpc.get_group(parallel_mode=ParallelMode.EMBEDDING))
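The None-grad branch above exists because ZeRO can leave no local gradient on a rank; substituting zeros keeps the collective symmetric between the two stages. A toy single-process sketch of the synchronization, where plain tensor addition stands in for the SUM all-reduce over the EMBEDDING group:

import torch

weight = torch.randn(8, 4)                 # tied embedding/output weight (same values on both stages)

grad_first = torch.randn_like(weight)      # gradient computed on the first pipeline stage
grad_last = None                           # e.g. ZeRO left no local gradient on the last stage
if grad_last is None:
    grad_last = torch.zeros_like(weight)   # contribute zeros so the collective still matches

synced_grad = grad_first + grad_last       # stands in for dist.all_reduce(grad, op=SUM)
assert torch.equal(synced_grad, grad_first)  # the zero contribution does not change the result
# both stages now apply the same update, so the two weight copies stay identical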


@@ -14,7 +14,10 @@ from torch.utils.data import DataLoader
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.engine import Engine
from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
from internlm.core.gradient_handler import (
EmbeddingSharedModuleGradientHandler,
PipelineSharedModuleGradientHandler,
)
from internlm.core.scheduler import (
InterleavedPipelineScheduler,
NonPipelineScheduler,
@@ -68,8 +71,12 @@ def initialize_trainer(
assert isinstance(optimizer, BaseOptimizer), "optimizer must be instance of BaseOptimizer"
# gradient handlers: PipelineSharedModuleGradientHandler and EmbeddingSharedModuleGradientHandler are supported now
# TODO: can refactor code here
if gpc.is_using_pp():
gpc.config.gradient_handler = [dict(type="PipelineSharedModuleGradientHandler")]
gpc.config.gradient_handler = [
dict(type="PipelineSharedModuleGradientHandler"),
dict(type="EmbeddingSharedModuleGradientHandler"),
]
gradient_handler_cfg = gpc.config.get("gradient_handler", [])
gradient_handlers = []
assert isinstance(gradient_handler_cfg, list), f"gradient_handler must be list but got {type(gradient_handler_cfg)}"
@@ -77,6 +84,14 @@
if isinstance(config, dict) and config.get("type") == "PipelineSharedModuleGradientHandler":
handler = PipelineSharedModuleGradientHandler(model=model, optimizer=optimizer)
gradient_handlers.append(handler)
if (
isinstance(config, dict)
and config.get("type") == "EmbeddingSharedModuleGradientHandler"
and gpc.config.model.get("tie_embeddings_and_output_weights", False)
and gpc.pipeline_parallel_size > 1
):
handler = EmbeddingSharedModuleGradientHandler(model=model, optimizer=optimizer)
gradient_handlers.append(handler)
# initialize scheduler for trainer
scheduler = None
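Put together, the registration above boils down to: the embedding handler is only attached when the model is actually split across pipeline stages and the tying flag is set. A self-contained sketch of that gating (the helper below is illustrative, not part of this commit):

def select_gradient_handlers(pipeline_parallel_size, tie_embeddings_and_output_weights):
    handlers = []
    if pipeline_parallel_size > 1:
        handlers.append("PipelineSharedModuleGradientHandler")
        if tie_embeddings_and_output_weights:
            handlers.append("EmbeddingSharedModuleGradientHandler")
    return handlers

print(select_gradient_handlers(4, True))   # both handlers
print(select_gradient_handlers(4, False))  # only the pipeline shared-module handler
print(select_gradient_handlers(1, True))   # [] -- nothing is shared across stages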


@@ -293,6 +293,8 @@ def args_sanity_check():
model._add_item("moe_use_residual", False)
if "moe_gate_k" not in model:
model._add_item("moe_gate_k", 2)
if "tie_embeddings_and_output_weights" not in model:
model._add_item("tie_embeddings_and_output_weights", False)
assert not (
gpc.config.model.num_experts > 1 and gpc.config.parallel.zero1.fsdp
), "FSDP does not support num_experts > 1"


@@ -40,22 +40,30 @@ class ScaleColumnParallelLinear(nn.Linear):
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
weight_scale: int = 1,
skip_weight_allocation: bool = False,
) -> None:
world_size = torch.distributed.get_world_size(process_group)
if out_features % world_size != 0:
raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
if skip_weight_allocation:
del self.weight
self.register_parameter("weight", None)
self.process_group = process_group
self.weight_scale = weight_scale
def forward(self, input): # pylint: disable=W0622
def forward(self, input, shared_weight: Optional[torch.Tensor] = None): # pylint: disable=W0622
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.
if shared_weight is None:
if self.weight is None:
raise RuntimeError("weight was not given in forward pass " "and skip_weight_allocation is True.")
shared_weight = self.weight
if self.weight_scale != 1:
weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
weight = shared_weight * self.weight_scale + (1 - self.weight_scale) * shared_weight.detach()
else:
weight = self.weight
weight = shared_weight
return fused_dense_func_torch(
input,
weight,
@@ -91,7 +99,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
weight_scale: int = 1,
skip_weight_allocation: bool = False,
) -> None:
# TODO: RewardModelLinear is not used for now
assert not skip_weight_allocation, "shared weight is not supported here for now"
super().__init__(in_features, out_features, process_group, bias, device, dtype, weight_scale)
torch.distributed.broadcast(self.weight, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
if bias:
@@ -102,7 +114,10 @@
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.
if self.weight_scale != 1:
weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
weight = (
self.weight * self.weight_scale
+ (1 - self.weight_scale) * self.weight.detach() # pylint: disable=not-callable
)
else:
weight = self.weight
return fused_dense_func_torch(
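To show the deferred-weight pattern above in isolation: the head is built without its own weight and the embedding table is passed in at forward time. The class below is a simplified single-GPU stand-in (no tensor parallelism, no weight_scale), not the ScaleColumnParallelLinear from this commit:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TiedHead(nn.Linear):
    # simplified stand-in: the weight is not allocated here and must be supplied at forward time
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features, bias=False)
        del self.weight                          # drop the allocated weight ...
        self.register_parameter("weight", None)  # ... and keep the attribute as an empty slot

    def forward(self, x, shared_weight=None):
        if shared_weight is None:
            if self.weight is None:
                raise RuntimeError("no weight was allocated and no shared weight was given")
            shared_weight = self.weight
        return F.linear(x, shared_weight)

embedding = nn.Embedding(1000, 64)                    # vocab_size=1000, hidden_size=64
head = TiedHead(64, 1000)                             # out_features must match the vocab size
tokens = torch.randint(0, 1000, (2, 8))
logits = head(embedding(tokens), embedding.weight)    # reuse the embedding table as output weight
print(logits.shape)                                   # torch.Size([2, 8, 1000])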


@@ -331,6 +331,7 @@ class PackedFlashInternLm1D(nn.Module):
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
(https://arxiv.org/abs/2201.05596) layer.
tie_embeddings_and_output_weights (bool, optional): default=True, whether the embedding and the output layer share the same weight.
"""
def __init__(
@@ -370,9 +371,15 @@
moe_drop_tokens: bool = True,
moe_use_rts: bool = True,
moe_use_residual: bool = False,
tie_embeddings_and_output_weights: bool = True,
):
super().__init__()
assert not (
embed_split_hidden and tie_embeddings_and_output_weights
), "shared embedding weights is not supported when embed_split_hidden is True."
self.tie_embeddings_and_output_weights = tie_embeddings_and_output_weights
checkpoint_layer_num = int(num_layers * checkpoint)
if is_reward:
@@ -446,6 +453,7 @@
device=device,
dtype=dtype,
weight_scale=embed_grad_scale,
skip_weight_allocation=self.tie_embeddings_and_output_weights,
)
for _, param in self.head.named_parameters():
normal_(std=0.0052)(param)
@@ -453,6 +461,9 @@
setattr(param, IS_TENSOR_PARALLEL, True)
self.parallel_output = parallel_output
if self.tie_embeddings_and_output_weights:
self.initialize_word_embeddings(hidden_size, vocab_size, dtype, device)
def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
# attention_mask: compute attention on the places where the value is 1
# the old condition may fail when the shared embedding is used
@@ -491,12 +502,79 @@
if hasattr(self, "norm"):
hidden_states = self.norm(hidden_states.float())
if hasattr(self, "head"):
hidden_states = self.head(hidden_states)
if self.tie_embeddings_and_output_weights:
hidden_states = self.head(hidden_states, self.shared_embedding_weight())
else:
hidden_states = self.head(hidden_states)
if not self.parallel_output:
hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
return hidden_states, moe_losses
def shared_embedding_weight(self):
if not self.tie_embeddings_and_output_weights:
raise Exception(
"shared_embedding_weight() called for last stage, but share_embeddings_and_output_weights is false"
)
assert isinstance(self.embedding, ParallelGPT2Embeddings)
return self.embedding.word_embeddings.weight
# TODO: refactor code
def initialize_word_embeddings(
self,
hidden_size: int = 768,
vocab_size: int = 50304,
dtype: torch.dtype = torch.float,
device: Optional[torch.device] = None,
):
if not self.tie_embeddings_and_output_weights:
raise Exception("initialize_word_embeddings() was called but " "tie_embeddings_and_output_weights is false")
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. Nothing to do if we aren't
# using pipeline parallelism.
if gpc.get_world_size(ParallelMode.PIPELINE) == 1:
return
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, all-reduce the grads of the two
# word_embeddings copies so that every applied weight
# update is the same on both stages.
if gpc.is_pipeline_last_stage():
assert not gpc.is_pipeline_first_stage()
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.embedding = ParallelGPT2Embeddings(
embed_dim=hidden_size,
vocab_size=vocab_size,
max_position_embeddings=-1,
process_group=gpc.get_group(ParallelMode.TENSOR),
padding_idx=None,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
for _, param in self.embedding.named_parameters():
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
setattr(param, IS_TENSOR_PARALLEL, True)
self.shared_embedding_weight().data.fill_(0)
# Ensure that first and last stages have the same initial parameter
# values.
if gpc.is_pipeline_first_stage() or gpc.is_pipeline_last_stage():
torch.distributed.all_reduce(
self.shared_embedding_weight().data, group=gpc.get_group(ParallelMode.EMBEDDING)
)
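The three numbered steps above can be replayed in a toy single-process form, with tensor addition standing in for the SUM all-reduce over the EMBEDDING group; shapes are illustrative:

import torch

torch.manual_seed(0)
vocab_size, hidden_size = 32, 8

emb_first = torch.randn(vocab_size, hidden_size)   # first stage: normally initialized embedding
emb_last = torch.zeros(vocab_size, hidden_size)    # last stage: second copy, filled with zeros (step 1)

# step 2: the SUM all-reduce makes both copies equal to the first stage's values
reduced = emb_first + emb_last
emb_first, emb_last = reduced.clone(), reduced.clone()
assert torch.equal(emb_first, emb_last)

# step 3: during training the gradients of the two copies are summed the same way,
# so both stages apply an identical update and the copies never drift apart
grad_first = torch.randn(vocab_size, hidden_size)  # grad flowing into the embedding lookup
grad_last = torch.randn(vocab_size, hidden_size)   # grad flowing into the output projection
synced_grad = grad_first + grad_last               # stands in for the gradient all-reduce
emb_first -= 0.1 * synced_grad
emb_last -= 0.1 * synced_grad
assert torch.equal(emb_first, emb_last)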
def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
"""
@@ -572,6 +650,7 @@
moe_drop_tokens: bool = True,
moe_use_rts: bool = True,
moe_use_residual: bool = False,
tie_embeddings_and_output_weights=False,
):
"""
Build model with config.
@@ -613,6 +692,7 @@
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
(https://arxiv.org/abs/2201.05596) layer.
tie_embeddings_and_output_weights (bool, optional): default=False, whether the embedding and the output layer share the same weight.
"""
cfg = dict(
@@ -646,6 +726,7 @@
moe_drop_tokens=moe_drop_tokens,
moe_use_rts=moe_use_rts,
moe_use_residual=moe_use_residual,
tie_embeddings_and_output_weights=tie_embeddings_and_output_weights,
)
return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)