mirror of https://github.com/hpcaitech/ColossalAI
[Shardformer] add assert for num of attention heads divisible by tp_size (#5670)
* add assert for num of attention heads divisible by tp_size

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 6af6d6fc9f
commit d3f34ee8cc
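For context, the guard this PR adds is the same three-line assert repeated in each policy below: before any attention module is sharded, check that the head count in the model config divides evenly by the tensor parallel size. A minimal standalone sketch of that check (the function and config class names here are illustrative, not part of the PR):

from dataclasses import dataclass


@dataclass
class DummyConfig:
    # Stand-in for a Hugging Face model config; only the field the guard reads.
    num_attention_heads: int


def check_tp_divisibility(config: DummyConfig, tensor_parallel_size: int) -> None:
    # Same divisibility guard each policy now runs before sharding attention.
    assert (
        config.num_attention_heads % tensor_parallel_size == 0
    ), "The number of attention heads must be divisible by tensor parallel size."


check_tp_divisibility(DummyConfig(num_attention_heads=12), tensor_parallel_size=4)   # passes
# check_tp_divisibility(DummyConfig(num_attention_heads=12), tensor_parallel_size=5) # AssertionError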
@@ -79,6 +79,9 @@ class BertPolicy(Policy):
         sp_partial_derived = sp_mode == "split_gather"
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[BertLayer] = ModulePolicyDescription(
                 attribute_replacement={
                     "attention.self.all_head_size": self.model.config.hidden_size
@@ -52,6 +52,9 @@ class BlipPolicy(Policy):
         norm_cls = col_nn.LayerNorm
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[Blip2EncoderLayer] = ModulePolicyDescription(
                 attribute_replacement={
                     "self_attn.num_heads": self.model.config.vision_config.num_attention_heads
@@ -61,6 +61,9 @@ class BloomPolicy(Policy):
         sp_partial_derived = sp_mode == "split_gather"
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.n_head % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[BloomBlock] = ModulePolicyDescription(
                 attribute_replacement={
                     "self_attention.hidden_size": self.model.config.hidden_size
@@ -47,6 +47,12 @@ class FalconPolicy(Policy):
         embedding_cls = col_nn.PaddingEmbedding
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
+            assert (
+                self.model.config.num_kv_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of key_value heads must be divisible by tensor parallel size."
             attn_attribute_replacement = {
                 "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
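Falcon, along with the Llama and Mistral hunks further down, gets a second assert because grouped-query attention keeps a separate, smaller count of key/value heads; both counts are split across ranks with integer division, so each must divide evenly on its own. A small illustrative sketch of that per-rank split (names are hypothetical, not taken from the diff):

def split_heads_per_rank(num_attention_heads: int, num_kv_heads: int, tp_size: int) -> tuple[int, int]:
    # Per-rank head counts under tensor parallelism; mirrors why both asserts are needed.
    assert num_attention_heads % tp_size == 0, "query heads must divide evenly across ranks"
    assert num_kv_heads % tp_size == 0, "key/value heads must divide evenly across ranks"
    return num_attention_heads // tp_size, num_kv_heads // tp_size


# Falcon-40B-style setup: 128 query heads, 8 key/value heads.
print(split_heads_per_rank(128, 8, tp_size=8))    # (16, 1) -> fine
# split_heads_per_rank(128, 8, tp_size=16)        # KV heads no longer split evenly -> AssertionError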
@@ -84,6 +84,9 @@ class GPT2Policy(Policy):
                 self.shard_config.enable_flash_attention = False
                 use_flash_attention = False
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[GPT2Model] = ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
@@ -57,6 +57,9 @@ class GPTJPolicy(Policy):
 
         overlap = self.shard_config.enable_sequence_overlap
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[GPTJModel] = ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
@@ -138,6 +138,12 @@ class LlamaPolicy(Policy):
             )
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
+            assert (
+                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of key_value heads must be divisible by tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
@@ -66,6 +66,12 @@ class MistralPolicy(Policy):
             )
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
+            assert (
+                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of key_value heads must be divisible by tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
@@ -76,6 +76,9 @@ class OPTPolicy(Policy):
             warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[OPTDecoderLayer] = ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
@@ -31,6 +31,9 @@ class SamPolicy(Policy):
         norm_cls = col_nn.LayerNorm
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[SamVisionLayer] = ModulePolicyDescription(
                 attribute_replacement={
                     "attn.num_attention_heads": self.model.config.vision_config.num_attention_heads
@@ -72,6 +72,9 @@ class T5BasePolicy(Policy):
             warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[T5Stack] = ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
@@ -44,6 +44,9 @@ class ViTPolicy(Policy):
             warnings.warn("Vit doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[ViTEmbeddings] = ModulePolicyDescription(
                 attribute_replacement={},
                 param_replacement=[],
@@ -78,6 +78,9 @@ class WhisperPolicy(Policy):
             warnings.warn("Whisper doesn't support jit fused operator now, will ignore the jit fused operator flag.")
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.encoder_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[WhisperEncoderLayer] = ModulePolicyDescription(
                 attribute_replacement={
                     "self_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size,
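Without the guard, the attribute replacements in these hunks would simply floor the per-rank head count under integer division, and the mismatch would likely only show up later as opaque shape errors inside the attention forward. A quick illustration of the silent floor, independent of any ColossalAI API:

num_attention_heads = 10
tensor_parallel_size = 4

heads_per_rank = num_attention_heads // tensor_parallel_size          # 2: two heads silently dropped
print(heads_per_rank * tensor_parallel_size == num_attention_heads)   # False -> sharded shapes no longer add up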