diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 85cf551b6..27021724c 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -1,5 +1,4 @@ import math -import warnings from typing import List, Optional, Tuple, Union import torch @@ -334,217 +333,6 @@ class CommandPipelineForwards: return {"hidden_states": hidden_states} -def get_command_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size): - from transformers.models.cohere.modeling_cohere import CohereAttention, apply_rotary_pos_emb, repeat_kv - - def forward( - self: CohereAttention, - hidden_states: torch.Tensor, - attention_mask: Optional[dict] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - - if sp_mode in ["split_gather", "ring"]: - q_len *= sp_size - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # sp: all-to-all comminucation when introducing sequence parallel - if sp_mode == "all_to_all": - query_states = all_to_all_comm(query_states, sp_group) - key_states = all_to_all_comm(key_states, sp_group) - value_states = all_to_all_comm(value_states, sp_group) - bsz, q_len, _ = query_states.size() - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." - attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - # sp: all-to-all comminucation when introducing sequence parallel - if sp_mode == "all_to_all": - attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - return forward - - -def get_command_model_forward_for_flash_attn(shard_config: ShardConfig): - logger = logging.get_logger(__name__) - assert shard_config.enable_flash_attention, "Flash Attention is not enabled." - - def forward( - self: CohereModel, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() - if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - # embed positions - hidden_states = inputs_embeds - - # in this case, attention_mask is a dict rather than a tensor - mask_shape = (hidden_states.shape[0], 1, past_seen_tokens, past_seen_tokens) - attention_mask = ColoAttention.prepare_attn_kwargs( - mask_shape, - hidden_states.dtype, - hidden_states.device, - q_padding_mask=attention_mask, - is_causal=True, - ) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = None - if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache - ) - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - return forward - - def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): from transformers import CohereForCausalLM @@ -647,7 +435,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): return forward -def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): +def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group, use_flash_attention): from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb def forward( @@ -692,41 +480,43 @@ def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) + if use_flash_attention: + assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." + attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) - attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) - else: - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - if not output_attentions: + if not output_attentions or use_flash_attention: attn_weights = None return attn_output, attn_weights, past_key_value return forward -def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group): +def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group, use_flash_attention): logger = logging.get_logger(__name__) def forward( @@ -779,8 +569,18 @@ def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group): ) if position_ids is None: position_ids = cache_position.unsqueeze(0) - - attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + if use_flash_attention: + hidden_states = inputs_embeds + mask_shape = (hidden_states.shape[0], 1, past_seen_tokens, past_seen_tokens) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + else: + attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) if sp_mode in ["ring", "split_gather"]: inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index e2a367f74..5284c89f0 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -19,8 +19,6 @@ from colossalai.shardformer.layer import ( from ..modeling.command import ( CommandPipelineForwards, - get_command_flash_attention_forward, - get_command_model_forward_for_flash_attn, get_command_seq_parallel_attention_forward, get_command_seq_parallel_model_forward, get_lm_forward_with_dist_cross_entropy, @@ -95,7 +93,10 @@ class CommandPolicy(Policy): self.append_or_create_method_replacement( description={ "forward": get_command_seq_parallel_model_forward( - sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group + sp_mode=sp_mode, + sp_size=sp_size, + sp_group=sp_group, + use_flash_attention=use_flash_attention, ), }, policy=policy, @@ -103,7 +104,9 @@ class CommandPolicy(Policy): ) self.append_or_create_method_replacement( description={ - "forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), + "forward": get_command_seq_parallel_attention_forward( + sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention + ), }, policy=policy, target_key=attn_cls, @@ -120,7 +123,9 @@ class CommandPolicy(Policy): ) self.append_or_create_method_replacement( description={ - "forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), + "forward": get_command_seq_parallel_attention_forward( + sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention + ), }, policy=policy, target_key=attn_cls, @@ -131,6 +136,7 @@ class CommandPolicy(Policy): sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group, + use_flash_attention=use_flash_attention, ), }, policy=policy, @@ -234,7 +240,9 @@ class CommandPolicy(Policy): if use_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size), + "forward": get_command_seq_parallel_attention_forward( + sp_mode, sp_group, sp_size, use_flash_attention=use_flash_attention + ), }, policy=policy, target_key=attn_cls, @@ -243,7 +251,9 @@ class CommandPolicy(Policy): # replace Command model forward method self.append_or_create_method_replacement( description={ - "forward": get_command_model_forward_for_flash_attn(self.shard_config), + "forward": get_command_seq_parallel_model_forward( + sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention + ), }, policy=policy, target_key=CohereModel,