From a83a2336e827afecb97db4171d3ca7b1e0fe8603 Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Tue, 18 Jun 2024 02:56:47 +0000 Subject: [PATCH] rebase master llama change --- colossalai/shardformer/modeling/command.py | 482 +++++++++++---------- colossalai/shardformer/policies/command.py | 91 +--- 2 files changed, 272 insertions(+), 301 deletions(-) diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 27021724c..83f4b97ff 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -1,13 +1,15 @@ import math +import warnings from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.models.cohere.modeling_cohere import CohereForCausalLM, CohereModel, StaticCache, repeat_kv +from transformers.models.cohere.modeling_cohere import CohereForCausalLM, CohereModel, StaticCache, apply_rotary_pos_emb, repeat_kv from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -333,6 +335,255 @@ class CommandPipelineForwards: return {"hidden_states": hidden_states} +def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if sp_mode is not None: + assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode" + assert (sp_size is not None) and ( + sp_group is not None + ), "Must specify sp_size and sp_group for sequence parallel" + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + bsz, q_len, _ = hidden_states.size() + # sp: modify sp_len when sequence parallel mode is ring + if sp_mode in ["split_gather", "ring"]: + q_len *= sp_size + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + query_states = all_to_all_comm(query_states, sp_group) + key_states = all_to_all_comm(key_states, sp_group) + value_states = all_to_all_comm(value_states, sp_group) + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if shard_config.enable_flash_attention: + assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." + attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) + else: + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_value + + return forward + + +def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): + logger = logging.get_logger(__name__) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + past_seen_tokens = 0 + seq_len = inputs_embeds.shape[1] + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # in this case, attention_mask is a dict rather than a tensor + if shard_config.enable_flash_attention: + mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + inputs_embeds.dtype, + inputs_embeds.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + else: + attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + + if sp_mode in ["ring", "split_gather"]: + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + elif sp_mode == "all_to_all": + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if sp_mode == "ring" or sp_mode == "split_gather": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return forward + def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): from transformers import CohereForCausalLM @@ -432,231 +683,4 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): attentions=outputs.attentions, ) - return forward - - -def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group, use_flash_attention): - from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - # sp: modify sp_len when sequence parallel mode is ring - if sp_mode in ["split_gather", "ring"]: - q_len *= sp_size - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # sp: all-to-all comminucation when introducing sequence parallel - if sp_mode == "all_to_all": - query_states = all_to_all_comm(query_states, sp_group) - key_states = all_to_all_comm(key_states, sp_group) - value_states = all_to_all_comm(value_states, sp_group) - bsz, q_len, _ = query_states.size() - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - past_key_value = getattr(self, "past_key_value", past_key_value) - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - if use_flash_attention: - assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." - attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - else: - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) - # sp: all-to-all comminucation when introducing sequence parallel - if sp_mode == "all_to_all": - attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) - attn_output = self.o_proj(attn_output) - - if not output_attentions or use_flash_attention: - attn_weights = None - return attn_output, attn_weights, past_key_value - - return forward - - -def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group, use_flash_attention): - logger = logging.get_logger(__name__) - - def forward( - self: CohereModel, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() - if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - if use_flash_attention: - hidden_states = inputs_embeds - mask_shape = (hidden_states.shape[0], 1, past_seen_tokens, past_seen_tokens) - attention_mask = ColoAttention.prepare_attn_kwargs( - mask_shape, - hidden_states.dtype, - hidden_states.device, - q_padding_mask=attention_mask, - is_causal=True, - ) - else: - attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) - - if sp_mode in ["ring", "split_gather"]: - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) - elif sp_mode == "all_to_all": - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value=past_key_values, output_attentions=output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - ) - - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = ( - next_decoder_cache.to_legacy_cache() - if isinstance(next_decoder_cache, Cache) - else next_decoder_cache - ) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - if sp_mode == "ring" or sp_mode == "split_gather": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) - elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - return forward + return forward \ No newline at end of file diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index 5284c89f0..77f96e462 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -19,8 +19,8 @@ from colossalai.shardformer.layer import ( from ..modeling.command import ( CommandPipelineForwards, - get_command_seq_parallel_attention_forward, - get_command_seq_parallel_model_forward, + get_command_flash_attention_forward, + get_command_flash_attention_model_forward, get_lm_forward_with_dist_cross_entropy, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -80,38 +80,7 @@ class CommandPolicy(Policy): ) sp_partial_derived = sp_mode in ["split_gather", "ring"] - use_flash_attention = self.shard_config.enable_flash_attention - # Currently sp cannot to be used with flashattention - if sp_mode in ["split_gather", "ring", "all_to_all"]: - if use_flash_attention: - warnings.warn( - f"Sequence parallelism mode {sp_mode} need to be used with FlashAttention, will disable FlashAttention automatically." - ) - use_flash_attention = False - - if sp_mode in ["split_gather", "ring"]: - self.append_or_create_method_replacement( - description={ - "forward": get_command_seq_parallel_model_forward( - sp_mode=sp_mode, - sp_size=sp_size, - sp_group=sp_group, - use_flash_attention=use_flash_attention, - ), - }, - policy=policy, - target_key=CohereModel, - ) - self.append_or_create_method_replacement( - description={ - "forward": get_command_seq_parallel_attention_forward( - sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention - ), - }, - policy=policy, - target_key=attn_cls, - ) - elif sp_mode == "all_to_all": + if sp_mode == "all_to_all": decoder_attribute_replacement = { "num_heads": self.model.config.num_attention_heads // sp_size, } @@ -121,27 +90,28 @@ class CommandPolicy(Policy): policy[attn_cls] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, ) + if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism: self.append_or_create_method_replacement( description={ - "forward": get_command_seq_parallel_attention_forward( - sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention - ), + "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), }, policy=policy, target_key=attn_cls, ) - self.append_or_create_method_replacement( - description={ - "forward": get_command_seq_parallel_model_forward( - sp_mode=sp_mode, - sp_size=sp_size, - sp_group=sp_group, - use_flash_attention=use_flash_attention, - ), - }, - policy=policy, - target_key=CohereModel, - ) + if self.pipeline_stage_manager is None: + self.append_or_create_method_replacement( + description={ + "forward": get_command_flash_attention_model_forward( + self.shard_config, + sp_mode=sp_mode, + sp_size=sp_size, + sp_group=sp_group, + ), + }, + policy=policy, + target_key=CohereModel, + ) + if self.shard_config.enable_tensor_parallelism: assert ( @@ -236,29 +206,6 @@ class CommandPolicy(Policy): target_key=CohereModel, ) - # use flash attention - if use_flash_attention: - self.append_or_create_method_replacement( - description={ - "forward": get_command_seq_parallel_attention_forward( - sp_mode, sp_group, sp_size, use_flash_attention=use_flash_attention - ), - }, - policy=policy, - target_key=attn_cls, - ) - if self.pipeline_stage_manager is None: - # replace Command model forward method - self.append_or_create_method_replacement( - description={ - "forward": get_command_seq_parallel_model_forward( - sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention - ), - }, - policy=policy, - target_key=CohereModel, - ) - return policy def postprocess(self):