diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE4_sft.py
index 92a93d0..dd0f782 100644
--- a/configs/7B_MoE4_sft.py
+++ b/configs/7B_MoE4_sft.py
@@ -129,6 +129,7 @@ model = dict(
     checkpoint=False,  # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1]
     num_attention_heads=NUM_ATTENTION_HEAD,
     embed_split_hidden=True,
+    tie_embeddings_and_output_weights=False,
     vocab_size=VOCAB_SIZE,
     embed_grad_scale=1,
     parallel_output=True,
diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py
index e9afa2e..4b45d65 100644
--- a/internlm/core/context/process_group_initializer.py
+++ b/internlm/core/context/process_group_initializer.py
@@ -48,6 +48,9 @@ class ParallelMode(Enum):
     # expert data parallel
     EXPERT_DATA = "expert_data"
 
+    # embedding share
+    EMBEDDING = "embedding"
+
     # dummy mode, only used during mode construction
     DUMMY = "dummy"
 
@@ -236,8 +239,8 @@ class Initializer_Pipeline(ProcessGroupInitializer):
         process_group = None
         cpu_group = None
         group_world_size = None
-        mode = ParallelMode.PIPELINE
 
+        groups = []
         for i in range(self.data_parallel_size):
             for j in range(self.pipeline_stage_size):
                 ranks = list(
@@ -265,7 +268,37 @@ class Initializer_Pipeline(ProcessGroupInitializer):
                     cpu_group = group_cpu
                     ranks_in_group = ranks
 
-        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
+                groups.append(
+                    (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, ParallelMode.PIPELINE)
+                )
+
+                # create the embedding communication group shared by the first and the last pipeline stage
+                if len(ranks) > 1:
+                    embedding_ranks = [ranks[0], ranks[-1]]
+                else:
+                    embedding_ranks = ranks
+                embed_group = dist.new_group(embedding_ranks, timeout=LLM_NCCL_TIMEOUT)
+                if use_cpu:
+                    group_cpu = (
+                        dist.new_group(embedding_ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                        if dist.get_backend() != "gloo"
+                        else embed_group
+                    )
+                else:
+                    group_cpu = None
+
+                if self.rank in ranks:
+                    local_rank = ranks.index(self.rank)
+                    group_world_size = len(embedding_ranks)
+                    process_group = embed_group
+                    cpu_group = group_cpu
+                    ranks_in_group = embedding_ranks
+
+                groups.append(
+                    (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, ParallelMode.EMBEDDING)
+                )
+
+        return groups
 
 
 class Initializer_Tensor(ProcessGroupInitializer):
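Not part of the patch: a minimal standalone sketch of the rank selection added above. Each pipeline group gets an EMBEDDING companion group made of its first and last stage, the two ranks that hold the tied embedding and head copies. The rank layout below is hypothetical; the real ranks come from Initializer_Pipeline's data/pipeline indexing, which this hunk only shows in part.

    def embedding_group_ranks(pipeline_ranks):
        # the tied copies live on the first and the last pipeline stage
        return [pipeline_ranks[0], pipeline_ranks[-1]] if len(pipeline_ranks) > 1 else pipeline_ranks

    # e.g. two data-parallel replicas with four pipeline stages each (made-up layout)
    for ranks in ([0, 2, 4, 6], [1, 3, 5, 7]):
        print(ranks, "->", embedding_group_ranks(ranks))
    # [0, 2, 4, 6] -> [0, 6]
    # [1, 3, 5, 7] -> [1, 7]
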
diff --git a/internlm/core/gradient_handler.py b/internlm/core/gradient_handler.py
index f2aaa1d..c6ac318 100644
--- a/internlm/core/gradient_handler.py
+++ b/internlm/core/gradient_handler.py
@@ -9,6 +9,7 @@ import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 from internlm.core.context import global_context as gpc
+from internlm.core.context.process_group_initializer import ParallelMode
 
 
 class BaseGradientHandler(ABC):
@@ -74,3 +75,26 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler):
                     dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
                     for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                         buf.copy_(synced)
+
+
+class EmbeddingSharedModuleGradientHandler(BaseGradientHandler):
+    """A helper class to handle all-reduce operations in the embedding share group.
+    An all-reduce collective communication is performed in :func:`handle_gradient`
+    between the first and the last pipeline stage, so that the tied embedding and
+    output-head weights receive identical gradient updates.
+
+    Args:
+        model (Module): Model where the gradients accumulate.
+        optimizer (Optimizer): Optimizer for updating the parameters.
+    """
+
+    def handle_gradient(self):
+        """Runs an all-reduce of the shared embedding gradient over the embedding group."""
+        if gpc.is_pipeline_first_stage() or gpc.is_pipeline_last_stage():
+            weight = self._model.model.shared_embedding_weight()
+            grad = weight.grad
+            # enabling ZeRO may leave grad as None
+            if grad is None:
+                grad = torch.zeros_like(weight)
+            torch.distributed.all_reduce(grad, group=gpc.get_group(parallel_mode=ParallelMode.EMBEDDING))
diff --git a/internlm/initialize/initialize_trainer.py b/internlm/initialize/initialize_trainer.py
index beb4a40..e6c779c 100644
--- a/internlm/initialize/initialize_trainer.py
+++ b/internlm/initialize/initialize_trainer.py
@@ -14,7 +14,10 @@ from torch.utils.data import DataLoader
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.engine import Engine
-from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
+from internlm.core.gradient_handler import (
+    EmbeddingSharedModuleGradientHandler,
+    PipelineSharedModuleGradientHandler,
+)
 from internlm.core.scheduler import (
     InterleavedPipelineScheduler,
     NonPipelineScheduler,
@@ -68,8 +71,12 @@ def initialize_trainer(
     assert isinstance(optimizer, BaseOptimizer), "optimizer must be instance of BaseOptimizer"
 
     # gradient handler, only support PipelineSharedModuleGradientHandler now
+    # TODO: refactor this gradient handler registration
     if gpc.is_using_pp():
-        gpc.config.gradient_handler = [dict(type="PipelineSharedModuleGradientHandler")]
+        gpc.config.gradient_handler = [
+            dict(type="PipelineSharedModuleGradientHandler"),
+            dict(type="EmbeddingSharedModuleGradientHandler"),
+        ]
     gradient_handler_cfg = gpc.config.get("gradient_handler", [])
     gradient_handlers = []
     assert isinstance(gradient_handler_cfg, list), f"gradient_handler must be list but got {type(gradient_handler_cfg)}"
@@ -77,6 +84,14 @@
         if isinstance(config, dict) and config.get("type") == "PipelineSharedModuleGradientHandler":
             handler = PipelineSharedModuleGradientHandler(model=model, optimizer=optimizer)
             gradient_handlers.append(handler)
+        if (
+            isinstance(config, dict)
+            and config.get("type") == "EmbeddingSharedModuleGradientHandler"
+            and gpc.config.model.get("tie_embeddings_and_output_weights", False)
+            and gpc.pipeline_parallel_size > 1
+        ):
+            handler = EmbeddingSharedModuleGradientHandler(model=model, optimizer=optimizer)
+            gradient_handlers.append(handler)
 
     # initialize scheduler for trainer
     scheduler = None
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 2087ae4..062bdb2 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -293,6 +293,8 @@ def args_sanity_check():
         model._add_item("moe_use_residual", False)
     if "moe_gate_k" not in model:
         model._add_item("moe_gate_k", 2)
+    if "tie_embeddings_and_output_weights" not in model:
+        model._add_item("tie_embeddings_and_output_weights", False)
     assert not (
         gpc.config.model.num_experts > 1 and gpc.config.parallel.zero1.fsdp
     ), "FSDP does not support num_experts > 1"
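Not part of the patch: a single-process sketch of why the gradient all-reduce performed by EmbeddingSharedModuleGradientHandler keeps the two tied copies identical. Both stages start from the same values, both end up with the same summed gradient, so both apply the same update. The tensors, shapes, and learning rate below are made up for illustration.

    import torch

    # two copies of the tied weight that start identical
    # (the patch enforces this at init time in initialize_word_embeddings)
    w_first = torch.randn(4, 3)
    w_last = w_first.clone()

    # each stage produces its own gradient for its copy ...
    g_first = torch.randn_like(w_first)  # grad flowing into the embedding lookup
    g_last = torch.randn_like(w_last)    # grad flowing into the output head
    # ... and the all-reduce (SUM) gives both stages the same total gradient,
    g_total = g_first + g_last
    # so an identical optimizer step keeps the two copies identical
    lr = 0.1
    w_first -= lr * g_total
    w_last -= lr * g_total
    assert torch.equal(w_first, w_last)
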
diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index d18308a..8ab8707 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -40,22 +40,30 @@ class ScaleColumnParallelLinear(nn.Linear):
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         weight_scale: int = 1,
+        skip_weight_allocation: bool = False,
     ) -> None:
         world_size = torch.distributed.get_world_size(process_group)
         if out_features % world_size != 0:
             raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
         super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
+        if skip_weight_allocation:
+            del self.weight
+            self.register_parameter("weight", None)
         self.process_group = process_group
         self.weight_scale = weight_scale
 
-    def forward(self, input):  # pylint: disable=W0622
+    def forward(self, input, shared_weight: Optional[torch.Tensor] = None):  # pylint: disable=W0622
         # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
         # we do an all_gather of x before doing the matmul.
         # If not, then the input is already gathered.
+        if shared_weight is None:
+            if self.weight is None:
+                raise RuntimeError("weight was not passed to forward() and skip_weight_allocation is True")
+            shared_weight = self.weight
         if self.weight_scale != 1:
-            weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
+            weight = shared_weight * self.weight_scale + (1 - self.weight_scale) * shared_weight.detach()
         else:
-            weight = self.weight
+            weight = shared_weight
         return fused_dense_func_torch(
             input,
             weight,
@@ -91,7 +99,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         weight_scale: int = 1,
+        skip_weight_allocation: bool = False,
     ) -> None:
+        # TODO: RewardModelLinear does not support weight sharing for now
+        assert not skip_weight_allocation, "shared weight is not supported here for now"
+
         super().__init__(in_features, out_features, process_group, bias, device, dtype, weight_scale)
         torch.distributed.broadcast(self.weight, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
         if bias:
@@ -102,7 +114,10 @@
         # we do an all_gather of x before doing the matmul.
         # If not, then the input is already gathered.
         if self.weight_scale != 1:
-            weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
+            weight = (
+                self.weight * self.weight_scale
+                + (1 - self.weight_scale) * self.weight.detach()  # pylint: disable=not-callable
+            )
         else:
             weight = self.weight
         return fused_dense_func_torch(
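Not part of the patch: a single-GPU sketch of the calling pattern that skip_weight_allocation enables. The head owns no weight of its own and is handed the embedding's weight at forward time. TiedHead below is a hypothetical stand-in for ScaleColumnParallelLinear without the tensor-parallel and weight-scale machinery.

    from typing import Optional

    import torch
    import torch.nn.functional as F

    class TiedHead(torch.nn.Linear):
        """Single-process analogue of ScaleColumnParallelLinear(skip_weight_allocation=True)."""

        def __init__(self, in_features, out_features, skip_weight_allocation=False):
            super().__init__(in_features, out_features, bias=False)
            if skip_weight_allocation:
                # no weight of its own; one is passed in at forward time
                del self.weight
                self.register_parameter("weight", None)

        def forward(self, x, shared_weight: Optional[torch.Tensor] = None):
            if shared_weight is None:
                if self.weight is None:
                    raise RuntimeError("no weight allocated and no shared weight passed")
                shared_weight = self.weight
            return F.linear(x, shared_weight)

    embedding = torch.nn.Embedding(100, 16)  # vocab_size=100, hidden_size=16
    head = TiedHead(16, 100, skip_weight_allocation=True)
    tokens = torch.randint(0, 100, (2, 8))
    logits = head(embedding(tokens), shared_weight=embedding.weight)
    print(logits.shape)  # torch.Size([2, 8, 100])
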
diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py
index 43489bc..fa53d27 100644
--- a/internlm/model/modeling_moe.py
+++ b/internlm/model/modeling_moe.py
@@ -331,6 +331,7 @@ class PackedFlashInternLm1D(nn.Module):
         moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
         moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
             (https://arxiv.org/abs/2201.05596) layer.
+        tie_embeddings_and_output_weights (bool, optional): default=True, share the embedding and output weights.
     """
 
     def __init__(
@@ -370,9 +371,15 @@ class PackedFlashInternLm1D(nn.Module):
         moe_drop_tokens: bool = True,
         moe_use_rts: bool = True,
         moe_use_residual: bool = False,
+        tie_embeddings_and_output_weights: bool = True,
     ):
         super().__init__()
 
+        assert not (
+            embed_split_hidden and tie_embeddings_and_output_weights
+        ), "tied embedding and output weights are not supported when embed_split_hidden is True."
+        self.tie_embeddings_and_output_weights = tie_embeddings_and_output_weights
+
         checkpoint_layer_num = int(num_layers * checkpoint)
 
         if is_reward:
@@ -446,6 +453,7 @@ class PackedFlashInternLm1D(nn.Module):
                 device=device,
                 dtype=dtype,
                 weight_scale=embed_grad_scale,
+                skip_weight_allocation=self.tie_embeddings_and_output_weights,
             )
             for _, param in self.head.named_parameters():
                 normal_(std=0.0052)(param)
@@ -453,6 +461,9 @@ class PackedFlashInternLm1D(nn.Module):
                     setattr(param, IS_TENSOR_PARALLEL, True)
         self.parallel_output = parallel_output
 
+        if self.tie_embeddings_and_output_weights:
+            self.initialize_word_embeddings(hidden_size, vocab_size, dtype, device)
+
     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
         # attention_mask: compute attention on the places where the value is 1
         # old condition may fail when use shared embedding
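Not part of the patch: the initialize_word_embeddings() call added above is defined in the next hunk. The initialization trick it relies on (steps 1 and 2 of the comment there) reduces to the following; the tensors are placeholders and the all-reduce is stood in by a plain sum, since only two ranks participate.

    import torch

    # the first stage keeps its real initialization; the last stage allocates a
    # zero-filled copy, and an all_reduce(SUM) over just these two ranks makes
    # both copies equal to the first stage's values
    first_stage_copy = torch.randn(8, 4)
    last_stage_copy = torch.zeros_like(first_stage_copy)
    reduced = first_stage_copy + last_stage_copy  # what the all-reduce yields on both ranks
    assert torch.equal(reduced, first_stage_copy)
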
@@ -491,12 +502,79 @@ class PackedFlashInternLm1D(nn.Module):
         if hasattr(self, "norm"):
             hidden_states = self.norm(hidden_states.float())
         if hasattr(self, "head"):
-            hidden_states = self.head(hidden_states)
+            if self.tie_embeddings_and_output_weights:
+                hidden_states = self.head(hidden_states, self.shared_embedding_weight())
+            else:
+                hidden_states = self.head(hidden_states)
         if not self.parallel_output:
             hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
 
         return hidden_states, moe_losses
 
+    def shared_embedding_weight(self):
+        if not self.tie_embeddings_and_output_weights:
+            raise Exception(
+                "shared_embedding_weight() was called but tie_embeddings_and_output_weights is False"
+            )
+        assert isinstance(self.embedding, ParallelGPT2Embeddings)
+
+        return self.embedding.word_embeddings.weight
+
+    # TODO: refactor code
+    def initialize_word_embeddings(
+        self,
+        hidden_size: int = 768,
+        vocab_size: int = 50304,
+        dtype: torch.dtype = torch.float,
+        device: Optional[torch.device] = None,
+    ):
+        if not self.tie_embeddings_and_output_weights:
+            raise Exception("initialize_word_embeddings() was called but tie_embeddings_and_output_weights is False")
+
+        # This function just initializes the word embeddings in the final stage
+        # when we are using pipeline parallelism. Nothing to do if we aren't
+        # using pipeline parallelism.
+        if gpc.get_world_size(ParallelMode.PIPELINE) == 1:
+            return
+
+        # Parameters are shared between the word embeddings layers, and the
+        # heads at the end of the model. In a pipelined setup with more than
+        # one stage, the initial embedding layer and the head are on different
+        # workers, so we do the following:
+        # 1. Create a second copy of word_embeddings on the last stage, with
+        #    initial parameters of 0.0.
+        # 2. Do an all-reduce between the first and last stage to ensure that
+        #    the two copies of word_embeddings start off with the same
+        #    parameter values.
+        # 3. In the training loop, all-reduce the grads of the two
+        #    word_embeddings copies before the optimizer step so that every
+        #    applied weight update is the same on both stages.
+        if gpc.is_pipeline_last_stage():
+            assert not gpc.is_pipeline_first_stage()
+            # set the word_embeddings weights to 0 here, then copy the first
+            # stage's weights using the all_reduce below.
+            self.embedding = ParallelGPT2Embeddings(
+                embed_dim=hidden_size,
+                vocab_size=vocab_size,
+                max_position_embeddings=-1,
+                process_group=gpc.get_group(ParallelMode.TENSOR),
+                padding_idx=None,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                device=device,
+                dtype=dtype,
+            )
+            for _, param in self.embedding.named_parameters():
+                if gpc.get_world_size(ParallelMode.TENSOR) > 1:
+                    setattr(param, IS_TENSOR_PARALLEL, True)
+            self.shared_embedding_weight().data.fill_(0)
+
+        # Ensure that first and last stages have the same initial parameter
+        # values.
+        if gpc.is_pipeline_first_stage() or gpc.is_pipeline_last_stage():
+            torch.distributed.all_reduce(
+                self.shared_embedding_weight().data, group=gpc.get_group(ParallelMode.EMBEDDING)
+            )
+
 
 def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
     """
@@ -572,6 +650,7 @@ def build_model_with_moe_cfg(
     moe_drop_tokens: bool = True,
     moe_use_rts: bool = True,
     moe_use_residual: bool = False,
+    tie_embeddings_and_output_weights: bool = False,
 ):
     """
     Build model with config.
@@ -613,6 +692,7 @@ def build_model_with_moe_cfg(
         moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
         moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
             (https://arxiv.org/abs/2201.05596) layer.
+        tie_embeddings_and_output_weights (bool, optional): default=False, share the embedding and output weights.
     """
 
     cfg = dict(
@@ -646,6 +726,7 @@ def build_model_with_moe_cfg(
         moe_drop_tokens=moe_drop_tokens,
        moe_use_rts=moe_use_rts,
         moe_use_residual=moe_use_residual,
+        tie_embeddings_and_output_weights=tie_embeddings_and_output_weights,
     )
 
     return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
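Not part of the patch: a sketch of the config fields needed to actually enable the feature (the shipped 7B_MoE4_sft.py above leaves it off). Field names follow the config touched by this diff; the surrounding values are placeholders, not the real 7B settings.

    # illustrative excerpt of a training config
    model = dict(
        embed_split_hidden=False,  # tying asserts that embed_split_hidden is disabled
        tie_embeddings_and_output_weights=True,
    )
    parallel = dict(
        pipeline=dict(size=2),  # the EMBEDDING group only does real work when the pipeline size > 1
    )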