mirror of https://github.com/InternLM/InternLM
modeling
parent e57ca246d9
commit a83b02acf4
@@ -782,7 +782,6 @@ class PackedFlashLlama1D(nn.Module):
        ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default.
        out_head_init_std (float): std used to init output lmhead weight. 0.02 by default.
        init_type (str): Initialization type. Use uniform or normal. "normal" by default.
        extra_pred_tokens (int): The number of extra output heads for multi-token prediction. 0 by default.
        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
    """

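For context on the `rope_base` argument: rotary position embeddings typically derive their per-channel rotation frequencies from this base, so a larger base slows the rotation and is the usual knob for longer contexts. A minimal sketch of that relationship (plain PyTorch, illustrative only, not the repository's rotary embedding code):

import torch

def rope_inverse_frequencies(head_dim: int, rope_base: int = 10000) -> torch.Tensor:
    # Standard RoPE frequencies: 1 / base^(2i / d) for each channel pair.
    # Increasing `rope_base` lowers the frequencies, i.e. positions rotate
    # more slowly, which stretches the effective context length.
    return 1.0 / (rope_base ** (torch.arange(0, head_dim, 2).float() / head_dim))
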
@@ -824,7 +823,6 @@ class PackedFlashLlama1D(nn.Module):
        ffn_other_init_std: float = 0.02,
        out_head_init_std: float = 0.02,
        init_type: str = "normal",
        extra_pred_tokens: int = 0,
        rope_base: int = 10000,
    ):
        super().__init__()

@@ -926,31 +924,6 @@ class PackedFlashLlama1D(nn.Module):
                if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                    setattr(param, IS_TENSOR_PARALLEL, True)

            if extra_pred_tokens > 0:
                self.extra_pred_tokens = extra_pred_tokens
                assert not is_reward, "extra_pred_tokens > 0 means using multi token prediction, not implemented for RLHF"
                self.extra_outputs = nn.ModuleList(
                    [
                        head_cls(
                            in_features=hidden_size,
                            out_features=vocab_size,
                            process_group=gpc.get_group(ParallelMode.TENSOR),
                            bias=False,
                            device=device,
                            dtype=dtype,
                            weight_scale=embed_grad_scale,
                        )
                        for _ in range(self.extra_pred_tokens)
                    ]
                )
                for _, param in self.extra_outputs.named_parameters():
                    if init_type == "normal":
                        normal_(std=out_head_init_std)(param)
                    else:
                        uniform_(std=out_head_init_std)(param)
                    if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                        setattr(param, IS_TENSOR_PARALLEL, True)

        self.parallel_output = parallel_output

    def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):

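The constructor block above attaches one additional output projection per extra predicted token, mirroring the main `head_cls` output head (same hidden/vocab sizes, tensor-parallel process group, and init std). A simplified single-device sketch of the same pattern, using a plain `nn.Linear` in place of the parallel head class (illustrative only):

import torch
from torch import nn

class ExtraPredHeads(nn.Module):
    """Illustrative stand-in for the extra multi-token-prediction heads."""

    def __init__(self, hidden_size: int, vocab_size: int, extra_pred_tokens: int, std: float = 0.02):
        super().__init__()
        # One independent vocabulary projection per extra predicted position.
        self.heads = nn.ModuleList(
            nn.Linear(hidden_size, vocab_size, bias=False) for _ in range(extra_pred_tokens)
        )
        for head in self.heads:
            nn.init.normal_(head.weight, std=std)

    def forward(self, hidden_states: torch.Tensor) -> list:
        # Each head maps the shared trunk representation to its own logits.
        return [head(hidden_states) for head in self.heads]
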
@@ -988,10 +961,8 @@ class PackedFlashLlama1D(nn.Module):

        if hasattr(self, "norm"):
            hidden_states = self.norm(hidden_states.float())
        if hasattr(self, "extra_pred_tokens") and self.extra_pred_tokens > 0:
            extra_hidden_states_list = [self.extra_outputs[i](hidden_states) for i in range(self.extra_pred_tokens)]
        else:
            extra_hidden_states_list = None

        extra_hidden_states_list = None
        if hasattr(self, "output"):
            hidden_states = self.output(hidden_states)

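This hunk also touches how `extra_hidden_states_list` is produced in the forward pass; what consumes it lies outside the diff. As one common recipe for multi-token prediction (an assumption here, not necessarily InternLM's actual loss), each extra head is scored against labels shifted one extra step further than the main next-token head:

import torch.nn.functional as F

def multi_token_prediction_loss(main_logits, extra_logits_list, labels):
    # main_logits: (B, T, V) next-token logits; extra_logits_list[i]: logits for
    # the token (i + 2) positions ahead; labels: (B, T) token ids.
    loss = F.cross_entropy(main_logits[:, :-1].flatten(0, 1), labels[:, 1:].flatten())
    for i, logits in enumerate(extra_logits_list):
        shift = i + 2  # extra head i looks one step further ahead than the previous one
        loss = loss + F.cross_entropy(logits[:, :-shift].flatten(0, 1), labels[:, shift:].flatten())
    return loss / (1 + len(extra_logits_list))
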
@@ -1086,7 +1057,6 @@ def build_model_with_cfg(
    ffn_other_init_std: float = 0.02,
    out_head_init_std: float = 0.02,
    init_type: str = "normal",
    extra_pred_tokens: int = 0,
    rope_base: int = 10000,
):
    """

@@ -1130,7 +1100,6 @@ def build_model_with_cfg(
        ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default.
        out_head_init_std (float): std used to init output lmhead weight. 0.02 by default.
        init_type (str): Initialization type. Use uniform or normal. "normal" by default.
        extra_pred_tokens (int): The number of extra output heads for multi-token prediction. 0 by default.
        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
    """
    if deepnorm:

@@ -1167,7 +1136,6 @@ def build_model_with_cfg(
        ffn_other_init_std=ffn_other_init_std,
        out_head_init_std=out_head_init_std,
        init_type=init_type,
        extra_pred_tokens=extra_pred_tokens,
        rope_base=rope_base,
    )

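With the new keyword arguments threaded through `build_model_with_cfg`, they can be set alongside the other model fields in a training config. An illustrative snippet (only `extra_pred_tokens` and `rope_base` come from this diff; the surrounding fields are placeholders):

model = dict(
    # ... existing fields such as hidden_size, num_layers, vocab_size, dtype ...
    extra_pred_tokens=2,   # enable two extra multi-token-prediction heads
    rope_base=10000,       # default rotary embedding base
)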