pull/532/head
lijiaxing 2023-12-11 15:31:46 +08:00
parent e57ca246d9
commit a83b02acf4
1 changed files with 2 additions and 34 deletions

View File

@ -782,7 +782,6 @@ class PackedFlashLlama1D(nn.Module):
ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default,
out_head_init_std (float): std used to init output lmhead weight. 0.02 by default,
init_type (str): Initialization type. Use uniform or normal. "normal" by default,
extra_pred_tokens (int): The number of extra output head for multi-token-prediction. 0 by default.
rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
"""
@ -824,7 +823,6 @@ class PackedFlashLlama1D(nn.Module):
ffn_other_init_std: float = 0.02,
out_head_init_std: float = 0.02,
init_type: str = "normal",
extra_pred_tokens: int = 0,
rope_base: int = 10000,
):
super().__init__()
@ -926,31 +924,6 @@ class PackedFlashLlama1D(nn.Module):
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
setattr(param, IS_TENSOR_PARALLEL, True)
if extra_pred_tokens > 0:
self.extra_pred_tokens = extra_pred_tokens
assert not is_reward, "extra_pred_tokens > 0 means using multi token prediction, not implement for RLHF"
self.extra_outputs = nn.ModuleList(
[
head_cls(
in_features=hidden_size,
out_features=vocab_size,
process_group=gpc.get_group(ParallelMode.TENSOR),
bias=False,
device=device,
dtype=dtype,
weight_scale=embed_grad_scale,
)
for _ in range(self.extra_pred_tokens)
]
)
for _, param in self.extra_outputs.named_parameters():
if init_type == "normal":
normal_(std=out_head_init_std)(param)
else:
uniform_(std=out_head_init_std)(param)
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
setattr(param, IS_TENSOR_PARALLEL, True)
self.parallel_output = parallel_output
def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
@ -988,10 +961,8 @@ class PackedFlashLlama1D(nn.Module):
if hasattr(self, "norm"):
hidden_states = self.norm(hidden_states.float())
if hasattr(self, "extra_pred_tokens") and self.extra_pred_tokens > 0:
extra_hidden_states_list = [self.extra_outputs[i](hidden_states) for i in range(self.extra_pred_tokens)]
else:
extra_hidden_states_list = None
extra_hidden_states_list = None
if hasattr(self, "output"):
hidden_states = self.output(hidden_states)
@ -1086,7 +1057,6 @@ def build_model_with_cfg(
ffn_other_init_std: float = 0.02,
out_head_init_std: float = 0.02,
init_type: str = "normal",
extra_pred_tokens: int = 0,
rope_base: int = 10000,
):
"""
@ -1130,7 +1100,6 @@ def build_model_with_cfg(
ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default,
out_head_init_std (float): std used to init output lmhead weight. 0.02 by default,
init_type (str): Initialization type. Use uniform or normal. "normal" by default,
extra_pred_tokens (int): The number of extra output head for multi-token-prediction. 0 by default.
rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
"""
if deepnorm:
@ -1167,7 +1136,6 @@ def build_model_with_cfg(
ffn_other_init_std=ffn_other_init_std,
out_head_init_std=out_head_init_std,
init_type=init_type,
extra_pred_tokens=extra_pred_tokens,
rope_base=rope_base,
)