@@ -135,6 +135,8 @@ def llama_model_forward(
     )
     sm_scale = 1.0 / (batch.head_dim**0.5)
 
+    norm_output = torch.empty_like(hidden_states)
+
     for layer_id, decoder_layer in enumerate(self.layers):
         hidden_states = decoder_layer(
             hidden_states,
@@ -149,12 +151,14 @@ def llama_model_forward(
             cos_sin=cos_sin,
             fd_inter_tensor=batch.fd_inter_tensor,
             output_tensor=output_tensor,
+            norm_output=norm_output,
             sm_scale=sm_scale,
         )
 
     if batch.is_prompts:
         hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
-    hidden_states = self.norm(hidden_states)
+        norm_output = torch.empty_like(hidden_states)
+    hidden_states = self.norm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)
 
     return hidden_states
 
@@ -174,6 +178,7 @@ def llama_decoder_layer_forward(
     cos_sin: Tuple[torch.Tensor] = None,
     fd_inter_tensor: FDIntermTensors = None,
     output_tensor: torch.Tensor = None,
+    norm_output: torch.Tensor = None,
     sm_scale: int = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     """This function will replace the forward function of LlamaDecoderLayer.
@@ -191,11 +196,12 @@ def llama_decoder_layer_forward(
         cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
         fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for storing intermediate values in flash-decoding. Defaults to None.
         output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
+        norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
         sm_scale (int, optional): Used for flash attention. Defaults to None.
     """
 
     residual = hidden_states
-    hidden_states = self.input_layernorm(hidden_states)
+    hidden_states = self.input_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)
     # Self Attention
     hidden_states = self.self_attn(
         hidden_states=hidden_states,
@@ -217,7 +223,7 @@ def llama_decoder_layer_forward(
 
     # Fully Connected
     residual = hidden_states
-    hidden_states = self.post_attention_layernorm(hidden_states)
+    hidden_states = self.post_attention_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)
     hidden_states = self.mlp(hidden_states)
     hidden_states = residual + hidden_states
 
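Note: the patch threads a preallocated norm_output buffer through the decoder layers so the RMSNorm can write into reusable storage instead of allocating a fresh tensor on every call. As a rough illustration of that calling convention only, the BufferedRMSNorm below is a hypothetical pure-PyTorch stand-in (not the fused kernel this patch actually dispatches to); it just shows a norm module that accepts an output buffer the way the new self.norm / input_layernorm calls do.

import torch
from torch import nn


class BufferedRMSNorm(nn.Module):
    # Hypothetical stand-in for the norm in the patch: the caller supplies a
    # preallocated `norm_output` buffer, so no new tensor is allocated per call.
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor, norm_output: torch.Tensor) -> torch.Tensor:
        # x is 2-D (num_tokens, hidden_size), matching the
        # hidden_states.reshape(-1, hidden_states.shape[-1]) call sites in the diff.
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        out = norm_output.view_as(x)  # reinterpret the caller's buffer, no new allocation
        torch.mul(x * torch.rsqrt(variance + self.eps), self.weight, out=out)
        return out


# Usage mirroring the patch: allocate the buffer once, reuse it for every norm call.
hidden_states = torch.randn(8, 4096)
norm_output = torch.empty_like(hidden_states)
norm = BufferedRMSNorm(4096)
hidden_states = norm(hidden_states, norm_output)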