[shardformer] to fix whisper test failed due to significant accuracy differences. (#4710)

* [shardformer] fix whisper test failed

* [shardformer] fix whisper test failed

* [shardformer] fix whisper test failed

* [shardformer] fix whisper test failed
pull/4727/head
flybird11111 2023-09-14 21:34:20 +08:00 committed by GitHub
parent e2c0e7f92a
commit 20190b49a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 1 deletions

View File

@ -114,7 +114,7 @@ We will follow this roadmap to develop Shardformer:
| bloom | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
| chatglm2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
| vit | [x] | [x] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| whisper | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| whisper | [x] | [x] | [x] | [x] | [x] | [ ] | [x] | [ ] | [ ] |
| sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |

View File

@ -57,6 +57,11 @@ class WhisperPolicy(Policy):
warnings.warn(
"Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
#TODO using the jit fused add_and_dropout affect the accuracy
if self.shard_config.enable_jit_fused:
self.shard_config.enable_jit_fused = False
warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused operator flag.")
if self.shard_config.enable_tensor_parallelism:
policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
"self_attn.embed_dim":