diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 559f9a56f..b1573ae16 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -114,7 +114,7 @@ We will follow this roadmap to develop Shardformer: | bloom | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | | chatglm2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | | vit | [x] | [x] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | -| whisper | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| whisper | [x] | [x] | [x] | [x] | [x] | [ ] | [x] | [ ] | [ ] | | sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | | blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | | roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 5d496f08e..31ba82166 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -57,6 +57,11 @@ class WhisperPolicy(Policy): warnings.warn( "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + # TODO: using the jit fused add_and_dropout affects the accuracy, so the flag is forced off for now + if self.shard_config.enable_jit_fused: + self.shard_config.enable_jit_fused = False + warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused operator flag.") + if self.shard_config.enable_tensor_parallelism: policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={ "self_attn.embed_dim":