[shardformer] Support customized policy for llamav2 based model with HybridParallelPlugin (#4624)

* Enable policy assignment in HybridPlugin and enable llama policy for llamav2

* Remove Policy from Plugin

* revert changes of plugin (HybridParallelModule)

* revert changes in plugin

* upgrade transformers

* revert transformers version

---------

Co-authored-by: flybird11111 <1829166702@qq.com>
eric8607242 2023-09-07 10:15:13 +08:00 committed by GitHub
parent 9709b8f502
commit c3d5fa3bac
1 changed file with 12 additions and 6 deletions

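The diff below adds grouped-query-attention support to the built-in LlamaPolicy, which is what a llamav2-based model falls back to. For a fully customized policy, the natural pattern is to subclass LlamaPolicy and adjust the policy dict it returns. A minimal sketch only: MyLlamaV2Policy is hypothetical, and how the policy is ultimately handed to HybridParallelPlugin depends on your ColossalAI version (the plugin-side changes in this PR were reverted per the commit message).

# Hypothetical customized policy for a llamav2-based model; a sketch assuming
# the colossalai.shardformer policy interface shown in the diff below.
from colossalai.shardformer.policies.llama import LlamaPolicy

class MyLlamaV2Policy(LlamaPolicy):
    def module_policy(self):
        # Start from the stock Llama policy and tweak it for the custom model.
        policy = super().module_policy()
        # ... adjust attribute_replacement / sub_module_replacement here ...
        return policy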

@@ -40,14 +40,20 @@ class LlamaPolicy(Policy):
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
         if self.shard_config.enable_tensor_parallelism:
-            policy[LlamaDecoderLayer] = ModulePolicyDescription(
-                attribute_replacement={
-                    "self_attn.hidden_size":
-                        self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                    "self_attn.num_heads":
-                        self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-                },
+            decoder_attribute_replacement = {
+                "self_attn.hidden_size":
+                    self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                "self_attn.num_heads":
+                    self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+            }
+            if getattr(self.model.config, "num_key_value_heads", False):
+                decoder_attribute_replacement["self_attn.num_key_value_heads"] = \
+                    self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
+
+            policy[LlamaDecoderLayer] = ModulePolicyDescription(
+                attribute_replacement=decoder_attribute_replacement,
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",