diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py
index 83f4b97ff..07a7f6cbf 100644
--- a/colossalai/shardformer/modeling/command.py
+++ b/colossalai/shardformer/modeling/command.py
@@ -3,13 +3,18 @@ import warnings
 from typing import List, Optional, Tuple, Union
 
 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.models.cohere.modeling_cohere import CohereForCausalLM, CohereModel, StaticCache, apply_rotary_pos_emb, repeat_kv
+from transformers.models.cohere.modeling_cohere import (
+    CohereForCausalLM,
+    CohereModel,
+    StaticCache,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -584,6 +589,7 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
 
     return forward
 
+
 def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
     from transformers import CohereForCausalLM
 
@@ -683,4 +689,4 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             attentions=outputs.attentions,
         )
 
-    return forward
\ No newline at end of file
+    return forward
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 0c04f7d38..c11ed99ac 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -67,7 +67,7 @@ class BertPolicy(Policy):
         else:
             norm_cls = col_nn.LayerNorm
 
-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
         assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for Bert"
         if sp_mode == "ring":
             warnings.warn(
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index 724a6b77c..20a75cf90 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -50,7 +50,7 @@ class BloomPolicy(Policy):
         else:
             norm_cls = col_nn.LayerNorm
 
-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
         assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for BLOOM"
         if sp_mode == "ring":
             warnings.warn(
diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py
index 4baf89f6a..01aa77e57 100644
--- a/colossalai/shardformer/policies/chatglm2.py
+++ b/colossalai/shardformer/policies/chatglm2.py
@@ -57,7 +57,7 @@ class ChatGLMPolicy(Policy):
         else:
             norm_cls = col_nn.LayerNorm
 
-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
         assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for ChatGLM2"
         if sp_mode == "ring":
             warnings.warn(
diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py
index 77f96e462..902baf2e1 100644
--- a/colossalai/shardformer/policies/command.py
+++ b/colossalai/shardformer/policies/command.py
@@ -73,11 +73,9 @@ class CommandPolicy(Policy):
             warnings.warn(
                 f"For Command, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
             )
-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
-        sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None
-        sp_group = (
-            self.shard_config.sequence_parallel_process_group if self.shard_config.enable_sequence_parallelism else None
-        )
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
+        sp_size = self.shard_config.sequence_parallel_size or None
+        sp_group = self.shard_config.sequence_parallel_process_group or None
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
 
         if sp_mode == "all_to_all":
@@ -112,7 +110,6 @@ class CommandPolicy(Policy):
                 target_key=CohereModel,
             )
 
-
         if self.shard_config.enable_tensor_parallelism:
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index 281ea88c2..cfe20000a 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -65,7 +65,7 @@ class GPT2Policy(Policy):
         else:
             norm_cls = col_nn.LayerNorm
 
-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
         assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for GPT2"
         if sp_mode == "ring":
             warnings.warn(
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 5852713c2..85ec6717d 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -73,11 +73,9 @@ class LlamaPolicy(Policy):
             warnings.warn(
                 f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
             )
-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
-        sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None
-        sp_group = (
-            self.shard_config.sequence_parallel_process_group if self.shard_config.enable_sequence_parallelism else None
-        )
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
+        sp_size = self.shard_config.sequence_parallel_size or None
+        sp_group = self.shard_config.sequence_parallel_process_group or None
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
 
         if sp_mode == "all_to_all":