diff --git a/internlm/model/modeling_llama.py b/internlm/model/modeling_llama.py
index a5362c4..16af3c2 100644
--- a/internlm/model/modeling_llama.py
+++ b/internlm/model/modeling_llama.py
@@ -782,7 +782,6 @@ class PackedFlashLlama1D(nn.Module):
         ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default,
         out_head_init_std (float): std used to init output lmhead weight. 0.02 by default,
         init_type (str): Initialization type. Use uniform or normal. "normal" by default,
-        extra_pred_tokens (int): The number of extra output head for multi-token-prediction. 0 by default.
         rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
     """
 
@@ -824,7 +823,6 @@ class PackedFlashLlama1D(nn.Module):
         ffn_other_init_std: float = 0.02,
         out_head_init_std: float = 0.02,
         init_type: str = "normal",
-        extra_pred_tokens: int = 0,
         rope_base: int = 10000,
     ):
         super().__init__()
@@ -926,31 +924,6 @@ class PackedFlashLlama1D(nn.Module):
                 if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                     setattr(param, IS_TENSOR_PARALLEL, True)
 
-        if extra_pred_tokens > 0:
-            self.extra_pred_tokens = extra_pred_tokens
-            assert not is_reward, "extra_pred_tokens > 0 means using multi token prediction, not implement for RLHF"
-            self.extra_outputs = nn.ModuleList(
-                [
-                    head_cls(
-                        in_features=hidden_size,
-                        out_features=vocab_size,
-                        process_group=gpc.get_group(ParallelMode.TENSOR),
-                        bias=False,
-                        device=device,
-                        dtype=dtype,
-                        weight_scale=embed_grad_scale,
-                    )
-                    for _ in range(self.extra_pred_tokens)
-                ]
-            )
-            for _, param in self.extra_outputs.named_parameters():
-                if init_type == "normal":
-                    normal_(std=out_head_init_std)(param)
-                else:
-                    uniform_(std=out_head_init_std)(param)
-                if gpc.get_world_size(ParallelMode.TENSOR) > 1:
-                    setattr(param, IS_TENSOR_PARALLEL, True)
-
         self.parallel_output = parallel_output
 
     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
@@ -988,10 +961,8 @@ class PackedFlashLlama1D(nn.Module):
 
         if hasattr(self, "norm"):
             hidden_states = self.norm(hidden_states.float())
-        if hasattr(self, "extra_pred_tokens") and self.extra_pred_tokens > 0:
-            extra_hidden_states_list = [self.extra_outputs[i](hidden_states) for i in range(self.extra_pred_tokens)]
-        else:
-            extra_hidden_states_list = None
+
+        extra_hidden_states_list = None
 
         if hasattr(self, "output"):
             hidden_states = self.output(hidden_states)
@@ -1086,7 +1057,6 @@ def build_model_with_cfg(
     ffn_other_init_std: float = 0.02,
     out_head_init_std: float = 0.02,
     init_type: str = "normal",
-    extra_pred_tokens: int = 0,
     rope_base: int = 10000,
 ):
     """
@@ -1130,7 +1100,6 @@ def build_model_with_cfg(
         ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default,
         out_head_init_std (float): std used to init output lmhead weight. 0.02 by default,
         init_type (str): Initialization type. Use uniform or normal. "normal" by default,
-        extra_pred_tokens (int): The number of extra output head for multi-token-prediction. 0 by default.
         rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
     """
     if deepnorm:
@@ -1167,7 +1136,6 @@ def build_model_with_cfg(
         ffn_other_init_std=ffn_other_init_std,
         out_head_init_std=out_head_init_std,
         init_type=init_type,
-        extra_pred_tokens=extra_pred_tokens,
         rope_base=rope_base,
     )
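
For context on what the deleted branch computed, below is a minimal, standalone sketch of the multi-token-prediction head pattern this patch removes. It uses plain torch.nn.Linear in place of the tensor-parallel head_cls, and the toy sizes are made up for illustration; it is not the project's API, only an approximation of the deleted logic.

    # Sketch of the removed multi-token-prediction heads (illustrative only).
    # nn.Linear stands in for the tensor-parallel head_cls; hidden_size,
    # vocab_size, and extra_pred_tokens mirror the names in the deleted code.
    import torch
    from torch import nn

    hidden_size, vocab_size, extra_pred_tokens = 32, 100, 2

    extra_outputs = nn.ModuleList(
        [nn.Linear(hidden_size, vocab_size, bias=False) for _ in range(extra_pred_tokens)]
    )

    hidden_states = torch.randn(8, hidden_size)  # final hidden states after the norm
    # Each extra head produced logits for one additional future token position.
    extra_hidden_states_list = [head(hidden_states) for head in extra_outputs]

After this patch, forward() always sets extra_hidden_states_list to None, so downstream code no longer receives extra logits.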