[misc] Bypass the huggingface bug to solve the mask mismatch problem (#5991)

Haze188 2024-08-15 14:40:26 +08:00 committed by GitHub
parent 4dd03999ec
commit 887d2d579b
1 changed file with 3 additions and 0 deletions


@@ -666,6 +666,9 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
+        # TODO: upgrade transformers to 4.44.0 to fix the bug, remove the hard code.
+        self._use_flash_attention_2 = shard_config.enable_flash_attention
+        self._use_sdpa = False if shard_config.enable_flash_attention else self._use_sdpa
         if self._use_flash_attention_2:
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
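
For context, the mismatch this commit works around comes from transformers preparing the attention mask differently depending on which attention backend the model believes is active: the FlashAttention-2 path passes the raw 2D padding mask through (or drops it to None when nothing is masked), while the SDPA/eager paths expand it into a 4D additive mask. Below is a minimal, self-contained sketch of the two formats; prepare_mask is a hypothetical stand-in written for illustration, not the actual transformers helper.

# Sketch of the two mask formats the forward above selects between.
# prepare_mask is a hypothetical illustration, not a transformers API.
import torch

def prepare_mask(attention_mask, use_flash_attention_2, dtype=torch.float16):
    """Return the mask in the format the selected attention kernel expects."""
    if use_flash_attention_2:
        # FlashAttention-2 consumes the raw 2D (batch, seq_len) padding mask,
        # or None when no position is actually masked out.
        return attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    # SDPA/eager kernels expect an expanded 4D additive mask of shape
    # (batch, 1, seq_len, seq_len) with large negative values at masked positions.
    bsz, seq_len = attention_mask.shape
    expanded = attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len).to(dtype)
    return (1.0 - expanded) * torch.finfo(dtype).min

mask = torch.tensor([[1, 1, 1, 0]])
print(prepare_mask(mask, use_flash_attention_2=True).shape)   # torch.Size([1, 4]): 2D passthrough
print(prepare_mask(mask, use_flash_attention_2=False).shape)  # torch.Size([1, 1, 4, 4]): 4D additive

If the model's _use_flash_attention_2 flag disagrees with the kernel that actually runs, a layer receives a mask in the wrong one of these two formats. Hard-setting _use_flash_attention_2 and _use_sdpa from shard_config.enable_flash_attention, as the added lines do, keeps the mask preparation consistent with the backend ColossalAI actually uses until the transformers upgrade noted in the TODO removes the need for the override.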