diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index 69ac3ad25..81c3f973f 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -74,6 +74,7 @@ class Linear1D_Col(ParallelModule):
                  process_group: ProcessGroup = None,
                  gather_output: bool = False,
                  seq_parallel: bool = False,
+                 seq_parallel_dim: int = 1,
                  overlap: bool = False,
                  skip_bias_add: bool = False,
                  weight: Optional[Parameter] = None,
@@ -87,6 +88,7 @@ class Linear1D_Col(ParallelModule):
         self.out_features = out_features
         self.gather_output = gather_output
         self.seq_parallel = seq_parallel
+        self.seq_parallel_dim = seq_parallel_dim
         self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
@@ -190,7 +192,8 @@ class Linear1D_Col(ParallelModule):
         bias = self.bias if not self.skip_bias_add else None

         if self.seq_parallel:
             output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
-                                                                            self.process_group, True, 1, self.overlap)
+                                                                            self.process_group, True,
+                                                                            self.seq_parallel_dim, self.overlap)
         else:
             output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
@@ -236,6 +239,7 @@ class Linear1D_Row(ParallelModule):
                  device: torch.device = None,
                  process_group: ProcessGroup = None,
                  seq_parallel: bool = False,
+                 seq_parallel_dim: int = 1,
                  parallel_input: bool = True,
                  skip_bias_add: bool = False,
                  weight: Optional[Parameter] = None,
@@ -254,6 +258,7 @@ class Linear1D_Row(ParallelModule):
         self.skip_bias_add = skip_bias_add
         self.process_group = process_group
         self.seq_parallel = seq_parallel
+        self.seq_parallel_dim = seq_parallel_dim
         self.num_partitions = dist.get_world_size(self.process_group)

         if skip_bias_add and not bias:
@@ -390,7 +395,8 @@ class Linear1D_Row(ParallelModule):
         else:
             output_parallel = F.linear(input_, self.weight)
             if self.seq_parallel:
-                output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
+                output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group,
+                                                                      self.seq_parallel_dim)
             else:
                 output = reduce_forward(output_parallel, self.process_group)

diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py
index 409e2e1f5..16dcf87c8 100644
--- a/colossalai/shardformer/modeling/chatglm2.py
+++ b/colossalai/shardformer/modeling/chatglm2.py
@@ -9,6 +9,8 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutpu
 from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig
+from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
 from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
     ChatGLMForConditionalGeneration,
@@ -146,6 +148,7 @@ class ChatGLMPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
     ):
         logger = logging.get_logger(__name__)
         output_hidden_states = (output_hidden_states
@@ -198,6 +201,11 @@ class ChatGLMPipelineForwards:
         all_self_attentions = None
         all_hidden_states = () if output_hidden_states else None
         start_idx, end_idx = stage_index[0], stage_index[1]
+
+        if shard_config.enable_sequence_parallelism:
+            hidden_states = split_forward_gather_backward(hidden_states,
+                                                          dim=0,
+                                                          process_group=shard_config.tensor_parallel_process_group)
         for idx in range(start_idx, end_idx):
             layer = self.encoder._get_layer(idx)
             if output_hidden_states:
@@ -214,6 +222,11 @@ class ChatGLMPipelineForwards:
             hidden_states, kv_cache = layer_ret
             if use_cache:
                 presents = presents + (kv_cache,)
+
+        if shard_config.enable_sequence_parallelism:
+            hidden_states = gather_forward_split_backward(hidden_states,
+                                                          dim=0,
+                                                          process_group=shard_config.tensor_parallel_process_group)
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
         if stage_manager.is_last_stage():
@@ -233,23 +246,22 @@ class ChatGLMPipelineForwards:
         return {'hidden_states': hidden_states}

     @staticmethod
-    def chatglm_for_conditional_generation_forward(
-        self: ChatGLMForConditionalGeneration,
-        input_ids: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        return_last_logit: Optional[bool] = False,
-        stage_manager: Optional[PipelineStageManager] = None,
-        hidden_states: Optional[torch.FloatTensor] = None,
-        stage_index: Optional[List[int]] = None,
-    ):
+    def chatglm_for_conditional_generation_forward(self: ChatGLMForConditionalGeneration,
+                                                   input_ids: Optional[torch.Tensor] = None,
+                                                   position_ids: Optional[torch.Tensor] = None,
+                                                   attention_mask: Optional[torch.Tensor] = None,
+                                                   past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+                                                   inputs_embeds: Optional[torch.Tensor] = None,
+                                                   labels: Optional[torch.Tensor] = None,
+                                                   use_cache: Optional[bool] = None,
+                                                   output_attentions: Optional[bool] = None,
+                                                   output_hidden_states: Optional[bool] = None,
+                                                   return_dict: Optional[bool] = None,
+                                                   return_last_logit: Optional[bool] = False,
+                                                   stage_manager: Optional[PipelineStageManager] = None,
+                                                   hidden_states: Optional[torch.FloatTensor] = None,
+                                                   stage_index: Optional[List[int]] = None,
+                                                   shard_config: ShardConfig = None):
         logger = logging.get_logger(__name__)
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
@@ -266,6 +278,7 @@ class ChatGLMPipelineForwards:
             stage_manager=stage_manager,
             hidden_states=hidden_states,
             stage_index=stage_index,
+            shard_config=shard_config,
         )
         if stage_manager.is_last_stage():
             hidden_states = transformer_outputs[0]
@@ -296,3 +309,91 @@ class ChatGLMPipelineForwards:
             )
         else:
             return transformer_outputs
+
+
+def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
+
+    def forward(
+        self,
+        input_ids,
+        position_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        full_attention_mask: Optional[torch.BoolTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        output_hidden_states = (output_hidden_states
+                                if output_hidden_states is not None else self.config.output_hidden_states)
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+
+        batch_size, seq_length = input_ids.shape
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embedding(input_ids)
+
+        if self.pre_seq_len is not None:
+            if past_key_values is None:
+                past_key_values = self.get_prompt(
+                    batch_size=batch_size,
+                    device=input_ids.device,
+                    dtype=inputs_embeds.dtype,
+                )
+            if attention_mask is not None:
+                attention_mask = torch.cat(
+                    [
+                        attention_mask.new_ones((batch_size, self.pre_seq_len)),
+                        attention_mask,
+                    ],
+                    dim=-1,
+                )
+
+        if full_attention_mask is None:
+            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
+
+        # Rotary positional embeddings
+        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
+        if position_ids is not None:
+            rotary_pos_emb = rotary_pos_emb[position_ids]
+        else:
+            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
+        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
+
+        # Run encoder.
+        # [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size]
+        inputs_embeds = split_forward_gather_backward(inputs_embeds,
+                                                      dim=0,
+                                                      process_group=shard_config.tensor_parallel_process_group)
+        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
+            inputs_embeds,
+            full_attention_mask,
+            rotary_pos_emb=rotary_pos_emb,
+            kv_caches=past_key_values,
+            use_cache=use_cache,
+            output_hidden_states=output_hidden_states,
+        )
+
+        hidden_states = gather_forward_split_backward(hidden_states,
+                                                      dim=0,
+                                                      process_group=shard_config.tensor_parallel_process_group)
+
+        if not return_dict:
+            return tuple(v for v in [
+                hidden_states,
+                presents,
+                all_hidden_states,
+                all_self_attentions,
+            ] if v is not None)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+    return forward

diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index fe091c658..19dd95fd6 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -155,20 +155,26 @@ class BertPolicy(Policy):

         # use flash attention
         if self.shard_config.enable_flash_attention:
-            policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_bert_flash_attention_forward(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=BertSelfAttention)

         # use jit operator
         if self.shard_config.enable_jit_fused:
-            policy[BertSelfOutput] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_jit_fused_bert_self_output_forward(),
                 'dropout_add': get_jit_fused_dropout_add_func(),
-            })
-            policy[BertOutput] = ModulePolicyDescription(method_replacement={
+            },
+                                                     policy=policy,
+                                                     target_key=BertSelfOutput)
+            self.append_or_create_method_replacement(description={
                 'forward': get_jit_fused_bert_output_forward(),
                 'dropout_add': get_jit_fused_dropout_add_func(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=BertOutput)

         return policy

diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py
index 3610e2c41..2e5388ab0 100644
--- a/colossalai/shardformer/policies/blip2.py
+++ b/colossalai/shardformer/policies/blip2.py
@@ -285,21 +285,26 @@ class BlipPolicy(Policy):

         # use flash attention
         if self.shard_config.enable_flash_attention:
-            policy[Blip2Attention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_blip2_flash_attention_forward(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=Blip2Attention)

         # use jit operator
         if self.shard_config.enable_jit_fused:
-            policy[Blip2QFormerSelfOutput] = ModulePolicyDescription(
-                method_replacement={
-                    'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
-                    'dropout_add': get_jit_fused_dropout_add_func(),
-                })
-            policy[Blip2QFormerOutput] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
+                'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            },
+                                                     policy=policy,
+                                                     target_key=Blip2QFormerSelfOutput)
+            self.append_or_create_method_replacement(description={
                 'forward': get_jit_fused_blip2_QFormer_output_forward(),
                 'dropout_add': get_jit_fused_dropout_add_func(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=Blip2QFormerOutput)

         return policy

diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index 2727272d0..21db13f6e 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -125,25 +125,33 @@ class BloomPolicy(Policy):
                                                      target_key=BloomModel)

         if self.shard_config.enable_flash_attention:
-            policy[BloomAttention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_bloom_flash_attention_forward(),
-                'dropout_add': get_dropout_add_func()
-            })
+                'dropout_add': get_dropout_add_func(),
+            },
+                                                     policy=policy,
+                                                     target_key=BloomAttention)

         # enable jit fused operator
         if self.shard_config.enable_jit_fused:
-            policy[BloomAttention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_jit_fused_bloom_attention_forward(),
                 'dropout_add': get_jit_fused_dropout_add_func(),
-            })
-            policy[BloomMLP] = ModulePolicyDescription(method_replacement={
+            },
+                                                     policy=policy,
+                                                     target_key=BloomAttention)
+            self.append_or_create_method_replacement(description={
                 'forward': get_jit_fused_bloom_mlp_forward(),
                 'dropout_add': get_jit_fused_dropout_add_func(),
-            })
-            policy[BloomGelu] = ModulePolicyDescription(method_replacement={
+            },
+                                                     policy=policy,
+                                                     target_key=BloomMLP)
+            self.append_or_create_method_replacement(description={
                 'forward': get_jit_fused_bloom_gelu_forward(),
                 'bloom_gelu_forward': get_jit_fused_gelu_forward_func(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=BloomGelu)

         return policy

diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py
index a15aa856d..b0d684a67 100644
--- a/colossalai/shardformer/policies/chatglm2.py
+++ b/colossalai/shardformer/policies/chatglm2.py
@@ -15,7 +15,11 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
     GLMBlock,
 )

-from ..modeling.chatglm2 import get_flash_core_attention_forward, get_jit_fused_glm_block_forward
+from ..modeling.chatglm2 import (
+    get_chatglm_sequence_parallel_forward_fn,
+    get_flash_core_attention_forward,
+    get_jit_fused_glm_block_forward,
+)
 from ..modeling.jit import get_jit_fused_dropout_add_func
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -45,8 +49,8 @@ class ChatGLMPolicy(Policy):

         policy = {}

+        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
         if self.shard_config.enable_tensor_parallelism:
-
             policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={},
                                                            sub_module_replacement=[
                                                                SubModuleReplacementDescription(
@@ -55,36 +59,42 @@ class ChatGLMPolicy(Policy):
                                                                )
                                                            ])

-            policy[GLMBlock] = ModulePolicyDescription(attribute_replacement={
-                "self_attention.num_attention_heads_per_partition":
-                    self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-                "self_attention.projection_size":
-                    (self.model.config.kv_channels * self.model.config.num_attention_heads) //
-                    self.shard_config.tensor_parallel_size,
-                "self_attention.qkv_hidden_size":
-                    (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) //
-                    self.shard_config.tensor_parallel_size,
-                "self_attention.core_attention.num_attention_heads_per_partition":
-                    self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-                "self_attention.core_attention.hidden_size_per_partition":
-                    self.model.config.kv_channels * self.model.config.num_attention_heads //
-                    self.shard_config.tensor_parallel_size,
-            },
-                                                       param_replacement=[],
-                                                       sub_module_replacement=[
-                                                           SubModuleReplacementDescription(
-                                                               suffix="self_attention.query_key_value",
-                                                               target_module=col_nn.Linear1D_Col,
-                                                           ),
-                                                           SubModuleReplacementDescription(
-                                                               suffix="self_attention.dense",
-                                                               target_module=col_nn.Linear1D_Row,
-                                                           ),
-                                                           SubModuleReplacementDescription(
-                                                               suffix="self_attention.core_attention.attention_dropout",
-                                                               target_module=col_nn.DropoutForParallelInput,
-                                                           ),
-                                                       ])
+            policy[GLMBlock] = ModulePolicyDescription(
+                attribute_replacement={
+                    "self_attention.num_attention_heads_per_partition":
+                        self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+                    "self_attention.projection_size":
+                        (self.model.config.kv_channels * self.model.config.num_attention_heads) //
+                        self.shard_config.tensor_parallel_size,
+                    "self_attention.qkv_hidden_size":
+                        (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) //
+                        self.shard_config.tensor_parallel_size,
+                    "self_attention.core_attention.num_attention_heads_per_partition":
+                        self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+                    "self_attention.core_attention.hidden_size_per_partition":
+                        self.model.config.kv_channels * self.model.config.num_attention_heads //
+                        self.shard_config.tensor_parallel_size,
+                },
+                param_replacement=[],
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(suffix="self_attention.query_key_value",
+                                                    target_module=col_nn.Linear1D_Col,
+                                                    kwargs={
+                                                        'seq_parallel': use_sequence_parallel,
+                                                        'seq_parallel_dim': 0
+                                                    }),
+                    SubModuleReplacementDescription(suffix="self_attention.dense",
+                                                    target_module=col_nn.Linear1D_Row,
+                                                    kwargs={
+                                                        'seq_parallel': use_sequence_parallel,
+                                                        'seq_parallel_dim': 0
+                                                    }),
+                    SubModuleReplacementDescription(
+                        suffix="self_attention.core_attention.attention_dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                ])
+
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
             if not self.model.config.rmsnorm:
@@ -124,16 +134,27 @@ class ChatGLMPolicy(Policy):

         # use flash attention
         if self.shard_config.enable_flash_attention:
-            policy[CoreAttention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_flash_core_attention_forward(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=CoreAttention)
+
+        # use sequence parallel
+        if use_sequence_parallel:
+            self.append_or_create_method_replacement(
+                description={'forward': get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
+                policy=policy,
+                target_key=ChatGLMModel)

         # use jit fused operator
         if self.shard_config.enable_jit_fused:
-            policy[GLMBlock] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_jit_fused_glm_block_forward(),
                 'dropout_add': get_jit_fused_dropout_add_func(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=GLMBlock)

         return policy

@@ -178,7 +199,13 @@ class ChatGLMPolicy(Policy):
         layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages)
         stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-        method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+        method_replacement = {
+            'forward':
+                partial(new_forward,
+                        stage_manager=stage_manager,
+                        stage_index=stage_index,
+                        shard_config=self.shard_config)
+        }
         self.append_or_create_method_replacement(description=method_replacement,
                                                  policy=policy,
                                                  target_key=model_cls)

diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index d34c0ae9f..acae26309 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -118,9 +118,11 @@ class GPT2Policy(Policy):
                                                      target_key=GPT2Block)

         if self.shard_config.enable_flash_attention:
-            policy[GPT2Attention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_gpt2_flash_attention_forward(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=GPT2Attention)

         if self.shard_config.enable_sequence_parallelism:
             policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}

diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 5ee95f3be..ccf776407 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -105,9 +105,11 @@ class LlamaPolicy(Policy):
                                                      target_key=LlamaModel)

         if self.shard_config.enable_flash_attention:
-            policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_llama_flash_attention_forward(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=LlamaAttention)

         return policy

diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py
index b1eba0432..9753d5a73 100644
--- a/colossalai/shardformer/policies/sam.py
+++ b/colossalai/shardformer/policies/sam.py
@@ -199,12 +199,16 @@ class SamPolicy(Policy):

         # use flash attention
         if self.shard_config.enable_flash_attention:
-            policy[SamAttention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_sam_flash_attention_forward(),
-            })
-            policy[SamVisionAttention] = ModulePolicyDescription(method_replacement={
+            },
+                                                     policy=policy,
+                                                     target_key=SamAttention)
+            self.append_or_create_method_replacement(description={
                 'forward': get_sam_vision_flash_attention_forward(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=SamVisionAttention)

         return policy

diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index 617720ee7..757bab95f 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -90,16 +90,20 @@ class ViTPolicy(Policy):
         # use flash attention
         if self.shard_config.enable_flash_attention:
-            policy[ViTSelfAttention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_vit_flash_self_attention_forward(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=ViTSelfAttention)

         # use jit fused operator
         if self.shard_config.enable_jit_fused:
-            policy[ViTOutput] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_jit_fused_vit_output_forward(),
                 'dropout_add': get_jit_fused_dropout_add_func(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=ViTOutput)

         return policy

     def new_model_class(self):

diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py
index c6473ee2a..d543df00b 100644
--- a/tests/kit/model_zoo/transformers/chatglm2.py
+++ b/tests/kit/model_zoo/transformers/chatglm2.py
@@ -12,8 +12,8 @@ from ..registry import ModelAttribute, model_zoo


 def data_gen():
-    input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075]], dtype=torch.int64)
-    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])
+    input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075, 632, 2075]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]])
     return dict(input_ids=input_ids, attention_mask=attention_mask)
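
For context on the `seq_parallel_dim` knob introduced in `linear.py` above: ChatGLM2's encoder keeps activations in `[seq_len, batch_size, hidden_size]` order, so the sequence split/gather in this patch runs along dim 0, while the existing GPT-2 sequence-parallel path splits along dim 1 (the default). Below is a minimal standalone sketch, plain PyTorch with made-up shapes and a hypothetical tensor-parallel size of 2 (none of it taken from the patch), showing which axis each layout shards.

# Standalone illustration only (no distributed setup required).
# tp_size and the tensor shapes are invented example values.
import torch

tp_size = 2                        # hypothetical tensor-parallel group size
seq_len, batch, hidden = 8, 2, 16

# ChatGLM2 activations are laid out as [seq_len, batch, hidden], so the
# sequence axis is dim 0 -- hence seq_parallel_dim=0 in the policy above.
glm_hidden = torch.randn(seq_len, batch, hidden)
glm_shards = torch.chunk(glm_hidden, tp_size, dim=0)
print([tuple(s.shape) for s in glm_shards])    # [(4, 2, 16), (4, 2, 16)]

# GPT-2 style activations are [batch, seq_len, hidden]; there the sequence
# axis is dim 1, which is why the pre-existing default stays seq_parallel_dim=1.
gpt_hidden = torch.randn(batch, seq_len, hidden)
gpt_shards = torch.chunk(gpt_hidden, tp_size, dim=1)
print([tuple(s.shape) for s in gpt_shards])    # [(2, 4, 16), (2, 4, 16)]

In the patched code these shards come from split_forward_gather_backward and the seq_parallel branches of Linear1D_Col/Linear1D_Row rather than torch.chunk, with the corresponding gathers handled in the backward pass.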