From 431b7bcf8f8b1c503261e3ce8815ab25fd8c2f07 Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Thu, 13 Jun 2024 06:47:58 +0000 Subject: [PATCH 1/7] Copy llama to command --- colossalai/shardformer/modeling/command.py | 1051 ++++++++++++++++++++ colossalai/shardformer/policies/command.py | 459 +++++++++ 2 files changed, 1510 insertions(+) create mode 100644 colossalai/shardformer/modeling/command.py create mode 100644 colossalai/shardformer/policies/command.py diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py new file mode 100644 index 000000000..01d10c8dc --- /dev/null +++ b/colossalai/shardformer/modeling/command.py @@ -0,0 +1,1051 @@ +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.cache_utils import Cache +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.models.llama.modeling_llama import ( + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaModel, + apply_rotary_pos_emb, + repeat_kv, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import ( + all_to_all_comm, + gather_forward_split_backward, + split_forward_gather_backward, +) +from colossalai.shardformer.shard import ShardConfig + +from ..layer import ColoAttention, cross_entropy_1d + + +class LlamaPipelineForwards: + """ + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. 
+ """ + + @staticmethod + def llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
+ if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0) + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if shard_config.enable_flash_attention: + # in this case, attention_mask is a dict rather than a tensor + mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + else: + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + start_idx, end_idx = stage_index[0], stage_index[1] + num_ckpt_layers = 0 + if self.gradient_checkpointing and self.training: + num_ckpt_layers = end_idx - start_idx + # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer + if shard_config.gradient_checkpoint_config is not None: + num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( + stage=stage_manager.stage, + num_stages=stage_manager.num_stages, + num_layers=end_idx - start_idx, + model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), + num_model_chunks=stage_manager.num_model_chunks, + ) + assert num_ckpt_layers <= end_idx - start_idx + + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if idx - start_idx < num_ckpt_layers: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if stage_manager.is_last_stage(): + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + ] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + # always return dict for imediate stage + return {"hidden_states": hidden_states} + + @staticmethod + def llama_for_causal_lm_forward( + self: LlamaForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = LlamaPipelineForwards.llama_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + past_key_values = None + + if stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + if shard_config.enable_tensor_parallelism and shard_config.parallel_output: + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + dtype=self.model.dtype, + ) + else: + shift_logits = shift_logits.view(-1, self.config.vocab_size) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def llama_for_sequence_classification_forward( + self: LlamaForSequenceClassification, + input_ids: torch.LongTensor = None, + attention_mask: 
Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + transformer_outputs = LlamaPipelineForwards.llama_model_forward( + self.model, + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + if input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + batch_size = hidden_states.shape[0] + + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + 
elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + +def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size): + from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb + + try: + from transformers.models.llama.modeling_llama import repeat_kv + except: + warnings.warn("using llamav1, llamav1 hasn't repeat_kv function") + + def forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[dict] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + if sp_mode in ["split_gather", "ring"]: + q_len *= sp_size + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + query_states = all_to_all_comm(query_states, sp_group) + key_states = all_to_all_comm(key_states, sp_group) + value_states = all_to_all_comm(value_states, sp_group) + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." 
+ attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + return forward + + +def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig): + logger = logging.get_logger(__name__) + assert shard_config.enable_flash_attention, "Flash Attention is not enabled." + + def forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + hidden_states = inputs_embeds + + # in this case, attention_mask is a dict rather than a tensor + mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return forward + + +def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): + from transformers import LlamaForCausalLM + + def forward( + self: LlamaForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + dtype=self.model.dtype, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return forward + + +def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + # sp: modify sp_len when sequence parallel mode is ring + if sp_mode in ["split_gather", "ring"]: + q_len *= sp_size + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = 
self.v_proj(hidden_states) + + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + query_states = all_to_all_comm(query_states, sp_group) + key_states = all_to_all_comm(key_states, sp_group) + value_states = all_to_all_comm(value_states, sp_group) + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) + else: + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_value + + return forward + + +def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): + logger = logging.get_logger(__name__) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: 
Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + # modify past_key_values_length when using sequence parallel + past_key_values_length *= sp_size + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if sp_mode in ["ring", "split_gather"]: + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + elif sp_mode == "all_to_all": + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, + ) + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + ) + + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if sp_mode == "ring" or sp_mode == "split_gather": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return forward diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py new file mode 100644 index 000000000..a9c982231 --- /dev/null +++ b/colossalai/shardformer/policies/command.py @@ -0,0 +1,459 @@ +import warnings +from functools import partial +from typing import Callable, Dict, List, Union + +import torch.nn as nn +from torch import Tensor +from torch.nn import Module + +from colossalai.shardformer.layer import ( + FusedRMSNorm, + Linear1D_Col, + Linear1D_Row, + PaddingEmbedding, + PaddingLMHead, + RMSNorm, + VocabParallelEmbedding1D, + VocabParallelLMHead1D, +) + +from ..modeling.llama import ( + LlamaPipelineForwards, + get_llama_flash_attention_forward, + get_llama_model_forward_for_flash_attn, + get_llama_seq_parallel_attention_forward, + get_llama_seq_parallel_model_forward, + get_lm_forward_with_dist_cross_entropy, +) +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"] + + +class LlamaPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaFlashAttention2, + 
LlamaModel, + LlamaSdpaAttention, + ) + + ATTN_IMPLEMENTATION = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, + } + policy = {} + + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + + if self.shard_config.enable_fused_normalization: + norm_cls = FusedRMSNorm + else: + norm_cls = RMSNorm + + if self.pipeline_stage_manager is not None: + self.shard_config.enable_sequence_parallelism = False + self.shard_config.enable_sequence_overlap = False + self.shard_config.sequence_parallelism_mode = None + warnings.warn( + f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False" + ) + sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None + sp_group = ( + self.shard_config.sequence_parallel_process_group if self.shard_config.enable_sequence_parallelism else None + ) + sp_partial_derived = sp_mode in ["split_gather", "ring"] + + use_flash_attention = self.shard_config.enable_flash_attention + # Currently sp cannot to be used with flashattention + if sp_mode in ["split_gather", "ring", "all_to_all"]: + if use_flash_attention: + warnings.warn( + f"Sequence parallelism mode {sp_mode} need to be used with FlashAttention, will disable FlashAttention automatically." + ) + use_flash_attention = False + + if sp_mode in ["split_gather", "ring"]: + self.append_or_create_method_replacement( + description={ + "forward": get_llama_seq_parallel_model_forward( + sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group + ), + }, + policy=policy, + target_key=LlamaModel, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), + }, + policy=policy, + target_key=attn_cls, + ) + elif sp_mode == "all_to_all": + decoder_attribute_replacement = { + "num_heads": self.model.config.num_attention_heads // sp_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size + + policy[attn_cls] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), + }, + policy=policy, + target_key=attn_cls, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_llama_seq_parallel_model_forward( + sp_mode=sp_mode, + sp_size=sp_size, + sp_group=sp_group, + ), + }, + policy=policy, + target_key=LlamaModel, + ) + + if self.shard_config.enable_tensor_parallelism: + assert ( + self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." + if hasattr(self.model.config, "num_key_value_heads"): + assert ( + self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size + and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size." 
+ decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( + self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size + ) + + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + kwargs=dict(seq_parallel_mode=sp_mode), + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=Linear1D_Row, + kwargs=dict(seq_parallel_mode=sp_mode), + ), + ], + ) + + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=LlamaModel, + ) + + # optimization configuration + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, + ), + ], + policy=policy, + target_key=LlamaDecoderLayer, + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", + target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, + ), + policy=policy, + target_key=LlamaModel, + ) + + # use flash attention + if use_flash_attention: + self.append_or_create_method_replacement( + description={ + "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size), + }, + policy=policy, + target_key=attn_cls, + ) + if self.pipeline_stage_manager is None: + # replace llama model forward method + self.append_or_create_method_replacement( + description={ + "forward": get_llama_model_forward_for_flash_attn(self.shard_config), + }, + policy=policy, + target_key=LlamaModel, + ) + + return policy + + def postprocess(self): + return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager is None: + return + + stage_manager = 
self.pipeline_stage_manager + if self.model.__class__.__name__ == "LlamaModel": + module = self.model + else: + module = self.model.model + + if stage_manager.is_interleave: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) + method_replacement = { + "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) + } + + else: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } + + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "LlamaModel": + module = self.model + else: + module = self.model.model + stage_manager = self.pipeline_stage_manager + + held_layers = [] + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embed_tokens) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(module.norm) + + else: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + + return held_layers + + +class LlamaModelPolicy(LlamaPolicy): + def module_policy(self): + policy = super().module_policy() + from transformers.models.llama.modeling_llama import LlamaModel + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=LlamaModel, new_forward=LlamaPipelineForwards.llama_model_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama model""" + return [] + + +class LlamaForCausalLMPolicy(LlamaPolicy): + def module_policy(self): + from transformers import LlamaForCausalLM + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + LlamaForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=VocabParallelLMHead1D, + kwargs={ + "gather_output": not self.shard_config.parallel_output, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, + ) + ], + ) + } + if self.shard_config.parallel_output: + new_item[LlamaForCausalLM].method_replacement = { + "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) + } + else: + new_item = { + LlamaForCausalLM: ModulePolicyDescription( + 
sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ) + ], + ) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + llama_model = self.model.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if ( + id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): + # tie weights + return [ + { + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] + return [] + + +class LlamaForSequenceClassificationPolicy(LlamaPolicy): + def module_policy(self): + from transformers import LlamaForSequenceClassification + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for sequence classification + new_item = { + LlamaForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) + } + policy.update(new_item) + # to be confirmed + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=LlamaForSequenceClassification, + new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama for sequence classification model""" + return [] From 94fbde6055498332ee2b50a7443d8600a8312181 Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Fri, 14 Jun 2024 03:04:56 +0000 Subject: [PATCH 2/7] change command --- colossalai/shardformer/layer/__init__.py | 4 +- colossalai/shardformer/layer/normalization.py | 110 ++++ colossalai/shardformer/modeling/command.py | 491 ++++++------------ .../shardformer/policies/auto_policy.py | 7 + colossalai/shardformer/policies/command.py | 150 ++---- diff.output | 59 +++ tests/kit/model_zoo/transformers/__init__.py | 6 + tests/kit/model_zoo/transformers/command.py | 81 +++ .../test_model/test_shard_command.py | 301 +++++++++++ 9 files changed, 776 insertions(+), 433 deletions(-) create mode 100644 diff.output create mode 100644 tests/kit/model_zoo/transformers/command.py create mode 100644 tests/test_shardformer/test_model/test_shard_command.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index f17fad1b6..8c70a26b7 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -4,7 +4,7 
@@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D from .loss import cross_entropy_1d -from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm +from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm, CohereLayerNorm, FusedCohereLayerNorm from .parallel_module import ParallelModule from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row @@ -23,6 +23,8 @@ __all__ = [ "RMSNorm", "FusedLayerNorm", "FusedRMSNorm", + "CohereLayerNorm", + "FusedCohereLayerNorm", "FusedLinear1D_Col", "ParallelModule", "PaddingEmbedding", diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 5aa212600..1f30c7741 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -4,6 +4,7 @@ import warnings from abc import ABC, abstractmethod import torch.nn as nn +from transformers.models.cohere.modeling_cohere import CohereLayerNorm from colossalai.lazy import LazyInitContext @@ -249,6 +250,115 @@ class FusedLayerNorm(BaseLayerNorm): return layernorm + +class CohereLayerNorm(BaseLayerNorm): + r""" + This is a wrapper around the transformers.models.cohere.CohereLayerNorm. It is meant to be used only with the from_native_module interface. + """ + + def __init__(self) -> None: + raise NotImplementedError( + "CohereLayerNorm is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface to convert a transformers.models.cohere.CohereLayerNorm module to colossalai layer norm module." + ) + + @staticmethod + def from_native_module(module: CohereLayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: + r""" + Convert a CohereLayerNorm module to colossalai layer norm module, + and optionally marking parameters for gradient aggregation. + + Args: + module (transformers.models.cohere.CohereLayerNorm): The CohereLayerNorm module to be converted. + sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. + + Returns: + nn.Module: The LayerNorm module. + + Raises: + AssertionError: If the provided module is not an instance of CohereLayerNorm + """ + + LazyInitContext.materialize(module) + + if sp_partial_derived: + # Since gradients are computed using only a subset of the data, + # aggregation of these gradients is necessary during backpropagation. + # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. + SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight) + + return module + + +class FusedCohereLayerNorm(BaseLayerNorm): + r""" + This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface. + """ + + def __init__(self) -> None: + raise NotImplementedError( + "FusedCohereLayerNorm is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface convert a transformers.models.cohere.CohereLayerNorm module to FusedLayerNorm module provided by apex." 
+ ) + + @staticmethod + def from_native_module(module: CohereLayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: + r""" + Convert a CohereLayerNorm module to FusedLayerNorm module provided by apex, + and optionally marking parameters for gradient aggregation. + + Args: + module (transformers.models.cohere.CohereLayerNorm): The CohereLayerNorm module to be converted. + sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. + + Returns: + nn.Module: Union[FastLayerNorm, FusedLayerNorm]. + + Raises: + AssertionError: If the provided module is not an instance of transformers.models.cohere.CohereLayerNorm. + """ + + LazyInitContext.materialize(module) + # get the attributes of the module + normalized_shape = module.weight.size(0) + eps = module.variance_epsilon + elementwise_affine = True + dtype = module.weight.dtype + device = module.weight.device + + # pick the suitable layernorm implementation + use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE + + if use_fast_ln: + if EnableFastLayerNorm: + ApexFusedLayerNorm = FastLayerNormWithHook + else: + # fall back to the normal fused layernorm is not built + ApexFusedLayerNorm = FusedLayerNormWithHook + else: + try: + ApexFusedLayerNorm = FusedLayerNormWithHook + except NameError: + warnings.warn( + "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead." + ) + return module + + layernorm = ( + ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) + ) + layernorm.weight = module.weight + + if sp_partial_derived: + # Since gradients are computed using only a subset of the data, + # aggregation of these gradients is necessary during backpropagation. + # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. + SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight) + SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias) + + return layernorm + + class FusedRMSNorm(BaseLayerNorm): """ This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface. 
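For reference, a minimal usage sketch of the two wrappers added above. This is illustrative only and not part of the patch; it assumes transformers exposes `CohereLayerNorm(hidden_size, eps)` under `transformers.models.cohere.modeling_cohere`, and that Apex is installed (otherwise `FusedCohereLayerNorm.from_native_module` warns and returns the original module unchanged).

```python
# Illustrative sketch (not part of the patch). Assumes the HF constructor signature
# CohereLayerNorm(hidden_size, eps); adjust if the installed transformers version differs.
import torch
from transformers.models.cohere.modeling_cohere import CohereLayerNorm as HFCohereLayerNorm

from colossalai.shardformer.layer import CohereLayerNorm, FusedCohereLayerNorm

native_norm = HFCohereLayerNorm(hidden_size=32, eps=1e-5)  # bias-free layer norm used by Command-R

# Plain wrapper: materializes the (possibly lazy) module and, when sp_partial_derived=True,
# marks its weight for gradient aggregation under sequence parallelism.
wrapped = CohereLayerNorm.from_native_module(native_norm, sp_partial_derived=False)

# Fused wrapper: swaps in an Apex fused layer norm that reuses the original weight,
# falling back to the native module (with a warning) when Apex is unavailable.
fused = FusedCohereLayerNorm.from_native_module(native_norm, sp_partial_derived=False)

x = torch.randn(2, 4, 32)
y = fused(x)  # same semantics as native_norm(x); the kernel used depends on the Apex install
```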
diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 01d10c8dc..d0e6ed0a6 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -7,21 +7,16 @@ import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.cache_utils import Cache -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) +from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) -from transformers.models.llama.modeling_llama import ( - LlamaForCausalLM, - LlamaForSequenceClassification, - LlamaModel, - apply_rotary_pos_emb, +from transformers.models.cohere.modeling_cohere import ( + CohereForCausalLM, + CohereModel, + StaticCache, repeat_kv, ) from transformers.utils import logging @@ -37,15 +32,15 @@ from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, cross_entropy_1d -class LlamaPipelineForwards: +class CommandPipelineForwards: """ - This class serves as a micro library for forward function substitution of Llama models + This class serves as a micro library for forward function substitution of Command models under pipeline setting. """ @staticmethod - def llama_model_forward( - self: LlamaModel, + def command_model_forward( + self: CohereModel, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -55,6 +50,7 @@ class LlamaPipelineForwards: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -68,6 +64,12 @@ class LlamaPipelineForwards: ) use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with pipeline parallelism. Setting `use_cache=False`..." + ) + use_cache = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds @@ -89,8 +91,17 @@ class LlamaPipelineForwards: batch_size, seq_length = input_shape device = hidden_states.device - seq_length_with_past = seq_length - past_key_values_length = 0 + past_seen_tokens = 0 + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device) + + seq_length_with_past = seq_length + past_seen_tokens # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
if output_attentions: @@ -103,18 +114,8 @@ class LlamaPipelineForwards: logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - if position_ids is None: - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0) + position_ids = cache_position.unsqueeze(0) # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage @@ -129,28 +130,9 @@ class LlamaPipelineForwards: is_causal=True, ) else: - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - hidden_states, - past_key_values_length, - ) + attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) - if self.gradient_checkpointing and self.training: + if self.gradient_checkpointing and self.training and use_cache: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." @@ -190,6 +172,7 @@ class LlamaPipelineForwards: past_key_values, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -199,6 +182,7 @@ class LlamaPipelineForwards: past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -237,8 +221,8 @@ class LlamaPipelineForwards: return {"hidden_states": hidden_states} @staticmethod - def llama_for_causal_lm_forward( - self: LlamaForCausalLM, + def command_for_causal_lm_forward( + self: CohereForCausalLM, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -249,6 +233,7 @@ class LlamaPipelineForwards: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -266,9 +251,9 @@ class LlamaPipelineForwards: Example: ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM + >>> from transformers import AutoTokenizer, CohereForCausalLM - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> model = CohereForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) >>> prompt = "Hey, are you conscious? Can you talk to me?" 
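The cache bookkeeping introduced in `command_model_forward` above (legacy tuple caches wrapped into `DynamicCache`, `cache_position` enumerating only the tokens seen in this forward pass) boils down to the following standalone sketch. The helper name and its isinstance check are illustrative, not part of the patch; the arithmetic mirrors the added code.

```python
# Illustrative helper (not in the patch): the past_seen_tokens / cache_position arithmetic
# applied by command_model_forward when use_cache is enabled.
from typing import Tuple

import torch
from transformers.cache_utils import Cache, DynamicCache


def build_cache_position(past_key_values, hidden_states: torch.Tensor) -> Tuple[Cache, torch.Tensor]:
    # Legacy tuple-of-tuples caches (or None) are first wrapped into a DynamicCache.
    if not isinstance(past_key_values, Cache):
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    past_seen_tokens = past_key_values.get_seq_length()
    # cache_position indexes only the tokens processed in this forward pass.
    cache_position = torch.arange(
        past_seen_tokens,
        past_seen_tokens + hidden_states.shape[1],
        device=hidden_states.device,
    )
    return past_key_values, cache_position
```

When `position_ids` is not supplied, the pipeline forward then derives it as `cache_position.unsqueeze(0)`.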
@@ -295,7 +280,7 @@ class LlamaPipelineForwards: output_hidden_states = False # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = LlamaPipelineForwards.llama_model_forward( + outputs = CommandPipelineForwards.command_model_forward( self.model, input_ids=input_ids, attention_mask=attention_mask, @@ -306,6 +291,7 @@ class LlamaPipelineForwards: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, @@ -316,6 +302,8 @@ class LlamaPipelineForwards: if stage_manager.is_last_stage(): hidden_states = outputs[0] logits = self.lm_head(hidden_states) + logits = logits * self.logit_scale + logits = logits.float() loss = None if labels is not None: # Shift so that tokens < n predict n @@ -355,137 +343,20 @@ class LlamaPipelineForwards: hidden_states = outputs.get("hidden_states") return {"hidden_states": hidden_states} - @staticmethod - def llama_for_sequence_classification_forward( - self: LlamaForSequenceClassification, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None, - ): - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - logger = logging.get_logger(__name__) - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
- if output_attentions: - logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") - output_attentions = False - if output_hidden_states: - logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") - output_hidden_states = False - - transformer_outputs = LlamaPipelineForwards.llama_model_forward( - self.model, - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config, - ) - - if input_ids is not None: - batch_size = input_ids.shape[0] - elif inputs_embeds is not None: - batch_size = inputs_embeds.shape[0] - else: - batch_size = hidden_states.shape[0] - - if stage_manager.is_last_stage(): - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - else: - hidden_states = transformer_outputs.get("hidden_states") - return {"hidden_states": hidden_states} - - -def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size): - from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - - try: - from transformers.models.llama.modeling_llama import repeat_kv - except: - warnings.warn("using llamav1, llamav1 hasn't repeat_kv function") +def get_command_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size): + from transformers.models.cohere.modeling_cohere import CohereAttention, apply_rotary_pos_emb + from transformers.models.cohere.modeling_cohere import repeat_kv + def 
forward( - self: LlamaAttention, + self: CohereAttention, hidden_states: torch.Tensor, attention_mask: Optional[dict] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if "padding_mask" in kwargs: @@ -520,13 +391,14 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size): "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -547,12 +419,12 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size): return forward -def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig): +def get_command_model_forward_for_flash_attn(shard_config: ShardConfig): logger = logging.get_logger(__name__) assert shard_config.enable_flash_attention, "Flash Attention is not enabled." def forward( - self: LlamaModel, + self: CohereModel, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -562,6 +434,7 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -572,41 +445,40 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig): return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) - if position_ids is None: - device = input_ids.device if input_ids is not None else 
inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + + past_seen_tokens = 0 + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + # embed positions hidden_states = inputs_embeds # in this case, attention_mask is a dict rather than a tensor - mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) + mask_shape = (hidden_states.shape[0], 1, past_seen_tokens, past_seen_tokens) attention_mask = ColoAttention.prepare_attn_kwargs( mask_shape, hidden_states.dtype, @@ -625,43 +497,38 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig): # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None - for idx, decoder_layer in enumerate(self.layers): + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -672,7 +539,11 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig): if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, 
all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -686,10 +557,10 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig): def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): - from transformers import LlamaForCausalLM + from transformers import CohereForCausalLM def forward( - self: LlamaForCausalLM, + self: CohereForCausalLM, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -700,6 +571,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -713,9 +585,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): Example: ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM + >>> from transformers import AutoTokenizer, CohereForCausalLM - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> model = CohereForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) >>> prompt = "Hey, are you conscious? Can you talk to me?" @@ -744,15 +616,13 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] - if self.config.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) + + logits = self.lm_head(hidden_states) + logits = logits * self.logit_scale logits = logits.float() loss = None @@ -788,7 +658,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): return forward -def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): +def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): + from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb + def forward( self, hidden_states: torch.Tensor, @@ -797,32 +669,16 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() # sp: modify sp_len when sequence parallel mode is ring if sp_mode in ["split_gather", "ring"]: q_len *= sp_size - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in 
range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": @@ -835,18 +691,14 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -854,18 +706,9 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) @@ -885,12 +728,8 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): else: attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) + + 
attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None @@ -899,11 +738,11 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): return forward -def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): +def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group): logger = logging.get_logger(__name__) def forward( - self, + self: CohereModel, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -913,6 +752,7 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -924,63 +764,43 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time, and must specify either one" + ) - seq_length_with_past = seq_length - past_key_values_length = 0 + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - # modify past_key_values_length when using sequence parallel - past_key_values_length *= sp_size - seq_length_with_past = seq_length_with_past + past_key_values_length + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, + past_seen_tokens = 0 + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + if position_ids is None: + position_ids = cache_position.unsqueeze(0) - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) if sp_mode in ["ring", "split_gather"]: inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) elif sp_mode == "all_to_all": inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device, - ) - - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length - ) - hidden_states = inputs_embeds - if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -990,14 +810,12 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None - if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, past_key_value, output_attentions) + return module(*inputs, past_key_value=past_key_values, output_attentions=output_attentions) return custom_forward @@ -1013,15 +831,20 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = ( + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, Cache) + else next_decoder_cache + ) if output_attentions: all_self_attns += (layer_outputs[1],) diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 69df021b0..008dead6b 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -192,6 +192,13 @@ _POLICY_LIST = { "transformers.models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification": PolicyLocation( file_name="qwen2", class_name="Qwen2ForSequenceClassificationPolicy" ), + # Command-R + "transformers.models.cohere.modeling_cohere.CohereModel": PolicyLocation( + file_name="command", class_name="CommandModelPolicy" + ), + "transformers.models.cohere.modeling_cohere.CohereForCausalLM": PolicyLocation( + file_name="command", class_name="CommandForCausalLMPolicy" + ), } diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index a9c982231..01fff3aa4 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -7,30 +7,30 @@ from torch import Tensor from torch.nn import Module from colossalai.shardformer.layer import ( - FusedRMSNorm, + FusedCohereLayerNorm, Linear1D_Col, Linear1D_Row, PaddingEmbedding, PaddingLMHead, - RMSNorm, + CohereLayerNorm, VocabParallelEmbedding1D, VocabParallelLMHead1D, ) -from ..modeling.llama import ( - LlamaPipelineForwards, - get_llama_flash_attention_forward, - get_llama_model_forward_for_flash_attn, - get_llama_seq_parallel_attention_forward, - get_llama_seq_parallel_model_forward, +from ..modeling.command import ( + CommandPipelineForwards, + get_command_flash_attention_forward, + get_command_model_forward_for_flash_attn, + get_command_seq_parallel_attention_forward, + get_command_seq_parallel_model_forward, get_lm_forward_with_dist_cross_entropy, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"] +__all__ = ["CommandPolicy", "CommandForCausalLMPolicy"] -class LlamaPolicy(Policy): +class CommandPolicy(Policy): def config_sanity_check(self): pass @@ -40,18 +40,18 @@ class LlamaPolicy(Policy): return 
self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.llama.modeling_llama import ( - LlamaAttention, - LlamaDecoderLayer, - LlamaFlashAttention2, - LlamaModel, - LlamaSdpaAttention, + from transformers.models.cohere.modeling_cohere import ( + CohereAttention, + CohereDecoderLayer, + CohereFlashAttention2, + CohereModel, + CohereSdpaAttention, ) ATTN_IMPLEMENTATION = { - "eager": LlamaAttention, - "flash_attention_2": LlamaFlashAttention2, - "sdpa": LlamaSdpaAttention, + "eager": CohereAttention, + "flash_attention_2": CohereFlashAttention2, + "sdpa": CohereSdpaAttention, } policy = {} @@ -64,16 +64,16 @@ class LlamaPolicy(Policy): embedding_cls = PaddingEmbedding if self.shard_config.enable_fused_normalization: - norm_cls = FusedRMSNorm + norm_cls = FusedCohereLayerNorm else: - norm_cls = RMSNorm + norm_cls = CohereLayerNorm if self.pipeline_stage_manager is not None: self.shard_config.enable_sequence_parallelism = False self.shard_config.enable_sequence_overlap = False self.shard_config.sequence_parallelism_mode = None warnings.warn( - f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False" + f"For Command, sequence parallelism is currently not compatible with pipeline parallelism, set to be False" ) sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None @@ -94,16 +94,16 @@ class LlamaPolicy(Policy): if sp_mode in ["split_gather", "ring"]: self.append_or_create_method_replacement( description={ - "forward": get_llama_seq_parallel_model_forward( + "forward": get_command_seq_parallel_model_forward( sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group ), }, policy=policy, - target_key=LlamaModel, + target_key=CohereModel, ) self.append_or_create_method_replacement( description={ - "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), + "forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), }, policy=policy, target_key=attn_cls, @@ -120,21 +120,21 @@ class LlamaPolicy(Policy): ) self.append_or_create_method_replacement( description={ - "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), + "forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), }, policy=policy, target_key=attn_cls, ) self.append_or_create_method_replacement( description={ - "forward": get_llama_seq_parallel_model_forward( + "forward": get_command_seq_parallel_model_forward( sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group, ), }, policy=policy, - target_key=LlamaModel, + target_key=CohereModel, ) if self.shard_config.enable_tensor_parallelism: @@ -155,7 +155,7 @@ class LlamaPolicy(Policy): self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size ) - policy[LlamaDecoderLayer] = ModulePolicyDescription( + policy[CohereDecoderLayer] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, sub_module_replacement=[ SubModuleReplacementDescription( @@ -204,7 +204,7 @@ class LlamaPolicy(Policy): kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=policy, - target_key=LlamaModel, + target_key=CohereModel, ) # optimization configuration @@ -215,14 +215,9 @@ class LlamaPolicy(Policy): target_module=norm_cls, kwargs={"sp_partial_derived": 
sp_partial_derived}, ), - SubModuleReplacementDescription( - suffix="post_attention_layernorm", - target_module=norm_cls, - kwargs={"sp_partial_derived": sp_partial_derived}, - ), ], policy=policy, - target_key=LlamaDecoderLayer, + target_key=CohereDecoderLayer, ) self.append_or_create_submodule_replacement( @@ -232,26 +227,26 @@ class LlamaPolicy(Policy): kwargs={"sp_partial_derived": sp_partial_derived}, ), policy=policy, - target_key=LlamaModel, + target_key=CohereModel, ) # use flash attention if use_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size), + "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size), }, policy=policy, target_key=attn_cls, ) if self.pipeline_stage_manager is None: - # replace llama model forward method + # replace Command model forward method self.append_or_create_method_replacement( description={ - "forward": get_llama_model_forward_for_flash_attn(self.shard_config), + "forward": get_command_model_forward_for_flash_attn(self.shard_config), }, policy=policy, - target_key=LlamaModel, + target_key=CohereModel, ) return policy @@ -266,7 +261,7 @@ class LlamaPolicy(Policy): return stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == "LlamaModel": + if self.model.__class__.__name__ == "CohereModel": module = self.model else: module = self.model.model @@ -293,7 +288,7 @@ class LlamaPolicy(Policy): """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None - if self.model.__class__.__name__ == "LlamaModel": + if self.model.__class__.__name__ == "CohereModel": module = self.model else: module = self.model.model @@ -323,15 +318,15 @@ class LlamaPolicy(Policy): return held_layers -class LlamaModelPolicy(LlamaPolicy): +class CommandModelPolicy(CommandPolicy): def module_policy(self): policy = super().module_policy() - from transformers.models.llama.modeling_llama import LlamaModel + from transformers.models.cohere.modeling_cohere import CohereModel if self.pipeline_stage_manager: # set None as default self.set_pipeline_forward( - model_cls=LlamaModel, new_forward=LlamaPipelineForwards.llama_model_forward, policy=policy + model_cls=CohereModel, new_forward=CommandPipelineForwards.command_model_forward, policy=policy ) return policy @@ -341,20 +336,20 @@ class LlamaModelPolicy(LlamaPolicy): return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - """No shared params in llama model""" + """No shared params in command model""" return [] -class LlamaForCausalLMPolicy(LlamaPolicy): +class CommandForCausalLMPolicy(CommandPolicy): def module_policy(self): - from transformers import LlamaForCausalLM + from transformers import CohereForCausalLM policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm new_item = { - LlamaForCausalLM: ModulePolicyDescription( + CohereForCausalLM: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", @@ -368,12 +363,12 @@ class LlamaForCausalLMPolicy(LlamaPolicy): ) } if self.shard_config.parallel_output: - new_item[LlamaForCausalLM].method_replacement = { + new_item[CohereForCausalLM].method_replacement = { "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) } else: new_item = { - LlamaForCausalLM: ModulePolicyDescription( + CohereForCausalLM: ModulePolicyDescription( sub_module_replacement=[ 
SubModuleReplacementDescription( suffix="lm_head", @@ -388,7 +383,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy): if self.pipeline_stage_manager: # set None as default self.set_pipeline_forward( - model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy + model_cls=CohereForCausalLM, new_forward=CommandPipelineForwards.command_for_causal_lm_forward, policy=policy ) return policy @@ -402,58 +397,17 @@ class LlamaForCausalLMPolicy(LlamaPolicy): return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - llama_model = self.model.model + command_model = self.model.model if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: if ( - id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + id(command_model.embed_tokens.weight) == id(self.model.lm_head.weight) and self.pipeline_stage_manager.num_stages > 1 ): # tie weights return [ { - 0: llama_model.embed_tokens.weight, + 0: command_model.embed_tokens.weight, self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, } ] - return [] - - -class LlamaForSequenceClassificationPolicy(LlamaPolicy): - def module_policy(self): - from transformers import LlamaForSequenceClassification - - policy = super().module_policy() - - if self.shard_config.enable_tensor_parallelism: - # add a new item for sequence classification - new_item = { - LlamaForSequenceClassification: ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True) - ) - ] - ) - } - policy.update(new_item) - # to be confirmed - if self.pipeline_stage_manager: - # set None as default - self.set_pipeline_forward( - model_cls=LlamaForSequenceClassification, - new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward, - policy=policy, - ) - return policy - - def get_held_layers(self) -> List[Module]: - """Get pipeline layers for current stage.""" - stage_manager = self.pipeline_stage_manager - held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.score) - return held_layers - - def get_shared_params(self) -> List[Dict[int, Tensor]]: - """No shared params in llama for sequence classification model""" - return [] + return [] \ No newline at end of file diff --git a/diff.output b/diff.output new file mode 100644 index 000000000..638edfee8 --- /dev/null +++ b/diff.output @@ -0,0 +1,59 @@ +diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py +index 5aa21260..01453a05 100644 +--- a/colossalai/shardformer/layer/normalization.py ++++ b/colossalai/shardformer/layer/normalization.py +@@ -165,7 +165,7 @@ class LayerNorm(BaseLayerNorm): + Raises: + AssertionError: If the provided module is not an instance of nn.LayerNorm. + """ +- assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm." ++ # assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm." + + LazyInitContext.materialize(module) + +@@ -174,7 +174,7 @@ class LayerNorm(BaseLayerNorm): + # aggregation of these gradients is necessary during backpropagation. + # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. 
+ SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight) +- SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias) ++ # SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias) + + return module + +@@ -209,9 +209,12 @@ class FusedLayerNorm(BaseLayerNorm): + + LazyInitContext.materialize(module) + # get the attributes of the module +- normalized_shape = module.normalized_shape +- eps = module.eps +- elementwise_affine = module.elementwise_affine ++ # normalized_shape = module.normalized_shape ++ # eps = module.eps ++ # elementwise_affine = module.elementwise_affine ++ normalized_shape = module.weight.size(0) ++ eps = module.variance_epsilon ++ elementwise_affine = True + dtype = module.weight.dtype + device = module.weight.device + +@@ -244,7 +247,7 @@ class FusedLayerNorm(BaseLayerNorm): + # aggregation of these gradients is necessary during backpropagation. + # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. + SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight) +- SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias) ++ # SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias) + + return layernorm + +diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py +index 6075f836..a7166e38 100644 +--- a/tests/test_shardformer/test_model/test_shard_command.py ++++ b/tests/test_shardformer/test_model/test_shard_command.py +@@ -210,6 +210,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, + ], + ) + def run_command_test(test_config): ++ print(test_config) + sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index d5bddcff0..05c17f562 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -22,3 +22,9 @@ try: from .qwen2 import * except ImportError: print("This version of transformers doesn't support qwen2.") + + +try: + from .command import * +except ImportError: + print("This version of transformers doesn't support Command-R.") diff --git a/tests/kit/model_zoo/transformers/command.py b/tests/kit/model_zoo/transformers/command.py new file mode 100644 index 000000000..6b15792b4 --- /dev/null +++ b/tests/kit/model_zoo/transformers/command.py @@ -0,0 +1,81 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +try: + from transformers import CohereConfig + + HAS_COMMAND = True +except ImportError: + HAS_COMMAND = False + +if HAS_COMMAND: + # =============================== + # Register Command-R + # =============================== + + def data_gen(): + + + input_ids = torch.Tensor( + [ + [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], + [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], + ] + ).long() + + attention_mask = torch.Tensor( + [ + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ] + ).long() + + return dict(input_ids=input_ids, attention_mask=attention_mask) + + # label is needed for casual lm + def data_gen_for_casual_lm(): + data = data_gen() + labels = data["input_ids"].clone() + 
data["labels"] = labels + return data + + # transform the output to a dict + output_transform_fn = lambda x: x + + # function to get the loss + loss_fn = lambda output: output["last_hidden_state"].mean() + loss_fn_for_casual_lm = lambda output: output["loss"] + loss_fn_for_seq_classification = lambda output: output["logits"].mean() + + config = CohereConfig( + num_hidden_layers=8, + hidden_size=32, + intermediate_size=64, + num_attention_heads=4, + max_position_embeddings=128, + ) + + if hasattr(config, "pad_token_id"): + config.pad_token_id = config.eos_token_id + + # register the following models + # transformers.CohereModel, + # transformers.CohereForCausalLM, + model_zoo.register( + name="transformers_command", + model_fn=lambda: transformers.CohereModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), + ) + model_zoo.register( + name="transformers_command_for_casual_lm", + model_fn=lambda: transformers.CohereForCausalLM(config), + data_gen_fn=data_gen_for_casual_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_casual_lm, + model_attribute=ModelAttribute(has_control_flow=True), + ) diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py new file mode 100644 index 000000000..6075f8369 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -0,0 +1,301 @@ +import os + +import pytest +import torch +import torch.distributed as dist +from torch.testing import assert_close + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import PipelineGradientCheckpointConfig +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False) + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + if enable_gradient_checkpointing: + # org_model.gradient_checkpointing_enable() + sharded_model.unwrap().gradient_checkpointing_enable() + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # unwrap model + command_model = unwrap_model(org_model, "CohereModel", "model") + shard_command_model = unwrap_model(sharded_model, "CohereModel", "model") + + row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"] + col_layer_for_check = ["layers[0].self_attn.o_proj"] + # Here we check the grad of layernorm because an all-reduce operation should be performed during sequence parallelism + 
norm_layer_for_check = ["layers[0].input_layernorm", "layers[1].input_layernorm"] + + # During pipeline parallelism, we cannot get the grad of norm layer during first stage, so we only check this when pp is not enbaled + if stage_manager is None: + norm_layer_for_check.append("norm") + + # Check the grad when using ZeRO-1 and ZeRO-2 + if ( + booster.plugin.zero_stage in [1, 2] + and booster.plugin.shard_config.enable_sequence_parallelism + and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" + ): + for p1, p2 in zip(command_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): + working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)] + grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p)) + grad_index = ( + 0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank + ) + grad = grads[grad_index] + sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] + assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0: + if test_config["precision"] == "fp32": + atol, rtol = 1e-6, 1e-4 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check( + command_model, shard_command_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + col_layer_grads = get_grad_tensors_for_check( + command_model, shard_command_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + norm_layer_grads = get_grad_tensors_for_check( + command_model, + shard_command_model, + norm_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + grads_to_check.update(norm_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == "CohereModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): + if test_config["precision"] == "fp32": + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + check_weight( + command_model, shard_command_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + + # check grads + check_all_grad_tensors(grads_to_check) + + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + 
"use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": False, + "precision": "fp32", + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) +def run_command_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "pp_style": "interleaved", + "num_model_chunks": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig( + num_ckpt_layers_per_stage=[0, 1, 2, 2], + ), + }, + ], +) +def run_command_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def check_command(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_command_test() + + +def check_command_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_command_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_command(): + spawn(check_command, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_command_3d(): + spawn(check_command_3d, 8) + + +if 
__name__ == "__main__": + test_command() + test_command_3d() From 1016bb32572a554de5e5966a7a772f45fb0f5d01 Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Fri, 14 Jun 2024 08:04:29 +0000 Subject: [PATCH 3/7] Fix Code Factor check --- tests/test_shardformer/test_model/test_shard_command.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py index 6075f8369..c4b640d97 100644 --- a/tests/test_shardformer/test_model/test_shard_command.py +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -210,7 +210,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ], ) def run_command_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm") + sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) From 2a7fa2e7d08bc2cb5cd67438489793ddff742ee4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 Jun 2024 08:05:06 +0000 Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/shardformer/layer/__init__.py | 2 +- colossalai/shardformer/layer/normalization.py | 1 - colossalai/shardformer/modeling/command.py | 22 ++++----------- colossalai/shardformer/policies/command.py | 8 +++--- diff.output | 18 ++++++------- tests/kit/model_zoo/transformers/command.py | 2 -- .../test_model/test_shard_command.py | 27 ++++++++++++++++--- 7 files changed, 44 insertions(+), 36 deletions(-) diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 8c70a26b7..33e500034 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -4,7 +4,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D from .loss import cross_entropy_1d -from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm, CohereLayerNorm, FusedCohereLayerNorm +from .normalization import CohereLayerNorm, FusedCohereLayerNorm, FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 1f30c7741..34a126904 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -250,7 +250,6 @@ class FusedLayerNorm(BaseLayerNorm): return layernorm - class CohereLayerNorm(BaseLayerNorm): r""" This is a wrapper around the transformers.models.cohere.CohereLayerNorm. It is meant to be used only with the from_native_module interface. 
diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index d0e6ed0a6..85cf551b6 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -3,22 +3,12 @@ import warnings from typing import List, Optional, Tuple, Union import torch -import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, -) -from transformers.models.cohere.modeling_cohere import ( - CohereForCausalLM, - CohereModel, - StaticCache, - repeat_kv, -) +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.cohere.modeling_cohere import CohereForCausalLM, CohereModel, StaticCache, repeat_kv from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -343,10 +333,9 @@ class CommandPipelineForwards: hidden_states = outputs.get("hidden_states") return {"hidden_states": hidden_states} + def get_command_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size): - from transformers.models.cohere.modeling_cohere import CohereAttention, apply_rotary_pos_emb - from transformers.models.cohere.modeling_cohere import repeat_kv - + from transformers.models.cohere.modeling_cohere import CohereAttention, apply_rotary_pos_emb, repeat_kv def forward( self: CohereAttention, @@ -728,7 +717,6 @@ def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): else: attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) if not output_attentions: diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index 01fff3aa4..6c4785912 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -7,12 +7,12 @@ from torch import Tensor from torch.nn import Module from colossalai.shardformer.layer import ( + CohereLayerNorm, FusedCohereLayerNorm, Linear1D_Col, Linear1D_Row, PaddingEmbedding, PaddingLMHead, - CohereLayerNorm, VocabParallelEmbedding1D, VocabParallelLMHead1D, ) @@ -383,7 +383,9 @@ class CommandForCausalLMPolicy(CommandPolicy): if self.pipeline_stage_manager: # set None as default self.set_pipeline_forward( - model_cls=CohereForCausalLM, new_forward=CommandPipelineForwards.command_for_causal_lm_forward, policy=policy + model_cls=CohereForCausalLM, + new_forward=CommandPipelineForwards.command_for_causal_lm_forward, + policy=policy, ) return policy @@ -410,4 +412,4 @@ class CommandForCausalLMPolicy(CommandPolicy): self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, } ] - return [] \ No newline at end of file + return [] diff --git a/diff.output b/diff.output index 638edfee8..0a84014f5 100644 --- a/diff.output +++ b/diff.output @@ -8,20 +8,20 @@ index 5aa21260..01453a05 100644 """ - assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm." + # assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm." - + LazyInitContext.materialize(module) - + @@ -174,7 +174,7 @@ class LayerNorm(BaseLayerNorm): # aggregation of these gradients is necessary during backpropagation. 
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight) - SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias) + # SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias) - + return module - + @@ -209,9 +209,12 @@ class FusedLayerNorm(BaseLayerNorm): - + LazyInitContext.materialize(module) # get the attributes of the module - normalized_shape = module.normalized_shape @@ -35,16 +35,16 @@ index 5aa21260..01453a05 100644 + elementwise_affine = True dtype = module.weight.dtype device = module.weight.device - + @@ -244,7 +247,7 @@ class FusedLayerNorm(BaseLayerNorm): # aggregation of these gradients is necessary during backpropagation. # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight) - SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias) + # SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias) - + return layernorm - + diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py index 6075f836..a7166e38 100644 --- a/tests/test_shardformer/test_model/test_shard_command.py @@ -55,5 +55,5 @@ index 6075f836..a7166e38 100644 def run_command_test(test_config): + print(test_config) sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm") - + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/kit/model_zoo/transformers/command.py b/tests/kit/model_zoo/transformers/command.py index 6b15792b4..a8b8842c5 100644 --- a/tests/kit/model_zoo/transformers/command.py +++ b/tests/kit/model_zoo/transformers/command.py @@ -16,8 +16,6 @@ if HAS_COMMAND: # =============================== def data_gen(): - - input_ids = torch.Tensor( [ [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py index c4b640d97..32c67d60e 100644 --- a/tests/test_shardformer/test_model/test_shard_command.py +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -79,10 +79,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 5e-3, 5e-3 row_layer_grads = get_grad_tensors_for_check( - command_model, shard_command_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + command_model, + shard_command_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, ) col_layer_grads = get_grad_tensors_for_check( - command_model, shard_command_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + command_model, + shard_command_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, ) norm_layer_grads = get_grad_tensors_for_check( command_model, @@ -121,7 +135,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 5e-3, 5e-3 check_weight( - command_model, shard_command_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + command_model, + shard_command_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, ) # check grads From 
9a290ab01333d63a331d43825acffdf114f30725 Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Fri, 14 Jun 2024 08:09:24 +0000 Subject: [PATCH 5/7] fix precommit --- colossalai/shardformer/layer/__init__.py | 2 +- colossalai/shardformer/layer/normalization.py | 1 - colossalai/shardformer/modeling/command.py | 22 ++----- colossalai/shardformer/policies/command.py | 8 ++- diff.output | 59 ------------------- tests/kit/model_zoo/transformers/command.py | 2 - .../test_model/test_shard_command.py | 27 ++++++++- 7 files changed, 35 insertions(+), 86 deletions(-) delete mode 100644 diff.output diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 8c70a26b7..33e500034 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -4,7 +4,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D from .loss import cross_entropy_1d -from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm, CohereLayerNorm, FusedCohereLayerNorm +from .normalization import CohereLayerNorm, FusedCohereLayerNorm, FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 1f30c7741..34a126904 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -250,7 +250,6 @@ class FusedLayerNorm(BaseLayerNorm): return layernorm - class CohereLayerNorm(BaseLayerNorm): r""" This is a wrapper around the transformers.models.cohere.CohereLayerNorm. It is meant to be used only with the from_native_module interface. 
diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index d0e6ed0a6..85cf551b6 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -3,22 +3,12 @@ import warnings from typing import List, Optional, Tuple, Union import torch -import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, -) -from transformers.models.cohere.modeling_cohere import ( - CohereForCausalLM, - CohereModel, - StaticCache, - repeat_kv, -) +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.cohere.modeling_cohere import CohereForCausalLM, CohereModel, StaticCache, repeat_kv from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -343,10 +333,9 @@ class CommandPipelineForwards: hidden_states = outputs.get("hidden_states") return {"hidden_states": hidden_states} + def get_command_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size): - from transformers.models.cohere.modeling_cohere import CohereAttention, apply_rotary_pos_emb - from transformers.models.cohere.modeling_cohere import repeat_kv - + from transformers.models.cohere.modeling_cohere import CohereAttention, apply_rotary_pos_emb, repeat_kv def forward( self: CohereAttention, @@ -728,7 +717,6 @@ def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): else: attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) if not output_attentions: diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index 01fff3aa4..6c4785912 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -7,12 +7,12 @@ from torch import Tensor from torch.nn import Module from colossalai.shardformer.layer import ( + CohereLayerNorm, FusedCohereLayerNorm, Linear1D_Col, Linear1D_Row, PaddingEmbedding, PaddingLMHead, - CohereLayerNorm, VocabParallelEmbedding1D, VocabParallelLMHead1D, ) @@ -383,7 +383,9 @@ class CommandForCausalLMPolicy(CommandPolicy): if self.pipeline_stage_manager: # set None as default self.set_pipeline_forward( - model_cls=CohereForCausalLM, new_forward=CommandPipelineForwards.command_for_causal_lm_forward, policy=policy + model_cls=CohereForCausalLM, + new_forward=CommandPipelineForwards.command_for_causal_lm_forward, + policy=policy, ) return policy @@ -410,4 +412,4 @@ class CommandForCausalLMPolicy(CommandPolicy): self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, } ] - return [] \ No newline at end of file + return [] diff --git a/diff.output b/diff.output deleted file mode 100644 index 638edfee8..000000000 --- a/diff.output +++ /dev/null @@ -1,59 +0,0 @@ -diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py -index 5aa21260..01453a05 100644 ---- a/colossalai/shardformer/layer/normalization.py -+++ b/colossalai/shardformer/layer/normalization.py -@@ -165,7 +165,7 @@ class LayerNorm(BaseLayerNorm): - Raises: - AssertionError: If the provided module is not an instance of nn.LayerNorm. 
- """ -- assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm." -+ # assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm." - - LazyInitContext.materialize(module) - -@@ -174,7 +174,7 @@ class LayerNorm(BaseLayerNorm): - # aggregation of these gradients is necessary during backpropagation. - # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. - SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight) -- SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias) -+ # SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias) - - return module - -@@ -209,9 +209,12 @@ class FusedLayerNorm(BaseLayerNorm): - - LazyInitContext.materialize(module) - # get the attributes of the module -- normalized_shape = module.normalized_shape -- eps = module.eps -- elementwise_affine = module.elementwise_affine -+ # normalized_shape = module.normalized_shape -+ # eps = module.eps -+ # elementwise_affine = module.elementwise_affine -+ normalized_shape = module.weight.size(0) -+ eps = module.variance_epsilon -+ elementwise_affine = True - dtype = module.weight.dtype - device = module.weight.device - -@@ -244,7 +247,7 @@ class FusedLayerNorm(BaseLayerNorm): - # aggregation of these gradients is necessary during backpropagation. - # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. - SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight) -- SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias) -+ # SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias) - - return layernorm - -diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py -index 6075f836..a7166e38 100644 ---- a/tests/test_shardformer/test_model/test_shard_command.py -+++ b/tests/test_shardformer/test_model/test_shard_command.py -@@ -210,6 +210,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, - ], - ) - def run_command_test(test_config): -+ print(test_config) - sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm") - - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/kit/model_zoo/transformers/command.py b/tests/kit/model_zoo/transformers/command.py index 6b15792b4..a8b8842c5 100644 --- a/tests/kit/model_zoo/transformers/command.py +++ b/tests/kit/model_zoo/transformers/command.py @@ -16,8 +16,6 @@ if HAS_COMMAND: # =============================== def data_gen(): - - input_ids = torch.Tensor( [ [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py index c4b640d97..32c67d60e 100644 --- a/tests/test_shardformer/test_model/test_shard_command.py +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -79,10 +79,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 5e-3, 5e-3 row_layer_grads = get_grad_tensors_for_check( - command_model, shard_command_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + command_model, + shard_command_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, ) col_layer_grads = get_grad_tensors_for_check( - 
command_model, shard_command_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + command_model, + shard_command_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, ) norm_layer_grads = get_grad_tensors_for_check( command_model, @@ -121,7 +135,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 5e-3, 5e-3 check_weight( - command_model, shard_command_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + command_model, + shard_command_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, ) # check grads From 8c3f524660c62ff51583082412e83997bb171ca2 Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Fri, 14 Jun 2024 09:14:01 +0000 Subject: [PATCH 6/7] Remove CohereLayerNorm and use existing layernorm --- colossalai/shardformer/layer/__init__.py | 4 +- colossalai/shardformer/layer/normalization.py | 146 ++---------------- colossalai/shardformer/policies/command.py | 8 +- 3 files changed, 22 insertions(+), 136 deletions(-) diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 33e500034..f17fad1b6 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -4,7 +4,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D from .loss import cross_entropy_1d -from .normalization import CohereLayerNorm, FusedCohereLayerNorm, FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm +from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row @@ -23,8 +23,6 @@ __all__ = [ "RMSNorm", "FusedLayerNorm", "FusedRMSNorm", - "CohereLayerNorm", - "FusedCohereLayerNorm", "FusedLinear1D_Col", "ParallelModule", "PaddingEmbedding", diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 34a126904..59e1da9fc 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -4,7 +4,6 @@ import warnings from abc import ABC, abstractmethod import torch.nn as nn -from transformers.models.cohere.modeling_cohere import CohereLayerNorm from colossalai.lazy import LazyInitContext @@ -141,32 +140,29 @@ class RMSNorm(BaseLayerNorm): class LayerNorm(BaseLayerNorm): r""" - This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface. + This is a wrapper around native LayerNorm. It is meant to be used only with the from_native_module interface. """ def __init__(self) -> None: raise NotImplementedError( "LayerNorm is not implemented as a physical class. " - "It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module." + "It is meant to be used only with the from_native_module interface to convert a native LayerNorm module to colossalai layer norm module." 
) @staticmethod - def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: + def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: r""" - Convert a native pytorch layer norm module to colossalai layer norm module, + Convert a native LayerNorm module to colossalai layer norm module, and optionally marking parameters for gradient aggregation. Args: - module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted. + module (nn.Module): The native LayerNorm module to be converted. sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. Returns: - nn.Module: The LayerNorm module. + nn.Module: The colossalai LayerNorm module. - Raises: - AssertionError: If the provided module is not an instance of nn.LayerNorm. """ - assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm." LazyInitContext.materialize(module) @@ -175,7 +171,8 @@ class LayerNorm(BaseLayerNorm): # aggregation of these gradients is necessary during backpropagation. # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight) - SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias) + if module.bias is not None: + SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias) return module @@ -188,140 +185,29 @@ class FusedLayerNorm(BaseLayerNorm): def __init__(self) -> None: raise NotImplementedError( "FusedLayerNorm is not implemented as a physical class. " - "It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex." + "It is meant to be used only with the from_native_module interface convert a native LayerNorm module to FusedLayerNorm module provided by apex." ) @staticmethod def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: r""" - Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex, + Convert a native LayerNorm module to FusedLayerNorm module provided by apex, and optionally marking parameters for gradient aggregation. Args: - module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted. + module (nn.Module): The native LayerNorm module to be converted. sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. Returns: nn.Module: Union[FastLayerNorm, FusedLayerNorm]. - Raises: - AssertionError: If the provided module is not an instance of nn.LayerNorm. """ LazyInitContext.materialize(module) # get the attributes of the module - normalized_shape = module.normalized_shape - eps = module.eps - elementwise_affine = module.elementwise_affine - dtype = module.weight.dtype - device = module.weight.device - - # pick the suitable layernorm implementation - use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE - - if use_fast_ln: - if EnableFastLayerNorm: - ApexFusedLayerNorm = FastLayerNormWithHook - else: - # fall back to the normal fused layernorm is not built - ApexFusedLayerNorm = FusedLayerNormWithHook - else: - try: - ApexFusedLayerNorm = FusedLayerNormWithHook - except NameError: - warnings.warn( - "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead." 
- ) - return module - - layernorm = ( - ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) - ) - layernorm.weight = module.weight - layernorm.bias = module.bias - - if sp_partial_derived: - # Since gradients are computed using only a subset of the data, - # aggregation of these gradients is necessary during backpropagation. - # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. - SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight) - SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias) - - return layernorm - - -class CohereLayerNorm(BaseLayerNorm): - r""" - This is a wrapper around the transformers.models.cohere.CohereLayerNorm. It is meant to be used only with the from_native_module interface. - """ - - def __init__(self) -> None: - raise NotImplementedError( - "CohereLayerNorm is not implemented as a physical class. " - "It is meant to be used only with the from_native_module interface to convert a transformers.models.cohere.CohereLayerNorm module to colossalai layer norm module." - ) - - @staticmethod - def from_native_module(module: CohereLayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: - r""" - Convert a CohereLayerNorm module to colossalai layer norm module, - and optionally marking parameters for gradient aggregation. - - Args: - module (transformers.models.cohere.CohereLayerNorm): The CohereLayerNorm module to be converted. - sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. - - Returns: - nn.Module: The LayerNorm module. - - Raises: - AssertionError: If the provided module is not an instance of CohereLayerNorm - """ - - LazyInitContext.materialize(module) - - if sp_partial_derived: - # Since gradients are computed using only a subset of the data, - # aggregation of these gradients is necessary during backpropagation. - # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. - SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight) - - return module - - -class FusedCohereLayerNorm(BaseLayerNorm): - r""" - This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface. - """ - - def __init__(self) -> None: - raise NotImplementedError( - "FusedCohereLayerNorm is not implemented as a physical class. " - "It is meant to be used only with the from_native_module interface convert a transformers.models.cohere.CohereLayerNorm module to FusedLayerNorm module provided by apex." - ) - - @staticmethod - def from_native_module(module: CohereLayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: - r""" - Convert a CohereLayerNorm module to FusedLayerNorm module provided by apex, - and optionally marking parameters for gradient aggregation. - - Args: - module (transformers.models.cohere.CohereLayerNorm): The CohereLayerNorm module to be converted. - sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. - - Returns: - nn.Module: Union[FastLayerNorm, FusedLayerNorm]. - - Raises: - AssertionError: If the provided module is not an instance of transformers.models.cohere.CohereLayerNorm. 
- """ - - LazyInitContext.materialize(module) - # get the attributes of the module - normalized_shape = module.weight.size(0) - eps = module.variance_epsilon - elementwise_affine = True + normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0]) + eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps + elementwise_affine = getattr(module, "elementwise_affine", True) dtype = module.weight.dtype device = module.weight.device @@ -339,7 +225,7 @@ class FusedCohereLayerNorm(BaseLayerNorm): ApexFusedLayerNorm = FusedLayerNormWithHook except NameError: warnings.warn( - "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead." + "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using native layernorm instead." ) return module @@ -347,6 +233,8 @@ class FusedCohereLayerNorm(BaseLayerNorm): ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) ) layernorm.weight = module.weight + if module.bias is not None: + layernorm.bias = module.bias if sp_partial_derived: # Since gradients are computed using only a subset of the data, diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index 6c4785912..e2a367f74 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -7,8 +7,8 @@ from torch import Tensor from torch.nn import Module from colossalai.shardformer.layer import ( - CohereLayerNorm, - FusedCohereLayerNorm, + FusedLayerNorm, + LayerNorm, Linear1D_Col, Linear1D_Row, PaddingEmbedding, @@ -64,9 +64,9 @@ class CommandPolicy(Policy): embedding_cls = PaddingEmbedding if self.shard_config.enable_fused_normalization: - norm_cls = FusedCohereLayerNorm + norm_cls = FusedLayerNorm else: - norm_cls = CohereLayerNorm + norm_cls = LayerNorm if self.pipeline_stage_manager is not None: self.shard_config.enable_sequence_parallelism = False From 3c7302ad0ef6ae9d0cf973c441c5df68067a315c Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Mon, 17 Jun 2024 08:50:05 +0000 Subject: [PATCH 7/7] merge model and attention forward --- colossalai/shardformer/modeling/command.py | 270 +++------------------ colossalai/shardformer/policies/command.py | 24 +- 2 files changed, 52 insertions(+), 242 deletions(-) diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 85cf551b6..27021724c 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -1,5 +1,4 @@ import math -import warnings from typing import List, Optional, Tuple, Union import torch @@ -334,217 +333,6 @@ class CommandPipelineForwards: return {"hidden_states": hidden_states} -def get_command_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size): - from transformers.models.cohere.modeling_cohere import CohereAttention, apply_rotary_pos_emb, repeat_kv - - def forward( - self: CohereAttention, - hidden_states: torch.Tensor, - attention_mask: Optional[dict] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is 
deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - - if sp_mode in ["split_gather", "ring"]: - q_len *= sp_size - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # sp: all-to-all comminucation when introducing sequence parallel - if sp_mode == "all_to_all": - query_states = all_to_all_comm(query_states, sp_group) - key_states = all_to_all_comm(key_states, sp_group) - value_states = all_to_all_comm(value_states, sp_group) - bsz, q_len, _ = query_states.size() - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." - attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - # sp: all-to-all comminucation when introducing sequence parallel - if sp_mode == "all_to_all": - attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - return forward - - -def get_command_model_forward_for_flash_attn(shard_config: ShardConfig): - logger = logging.get_logger(__name__) - assert shard_config.enable_flash_attention, "Flash Attention is not enabled." 
- - def forward( - self: CohereModel, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() - if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - # embed positions - hidden_states = inputs_embeds - - # in this case, attention_mask is a dict rather than a tensor - mask_shape = (hidden_states.shape[0], 1, past_seen_tokens, past_seen_tokens) - attention_mask = ColoAttention.prepare_attn_kwargs( - mask_shape, - hidden_states.dtype, - hidden_states.device, - q_padding_mask=attention_mask, - is_causal=True, - ) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = None - if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache - ) - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - return forward - - def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): from transformers import CohereForCausalLM @@ -647,7 +435,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): return forward -def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): +def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group, use_flash_attention): from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb def forward( @@ -692,41 +480,43 @@ def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) + if use_flash_attention: + assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." 
+ attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) - attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) - else: - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - if not output_attentions: + if not output_attentions or use_flash_attention: attn_weights = None return attn_output, attn_weights, past_key_value return forward -def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group): +def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group, use_flash_attention): logger = logging.get_logger(__name__) def forward( @@ -779,8 +569,18 @@ def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group): ) if position_ids is None: position_ids = cache_position.unsqueeze(0) - - attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + if use_flash_attention: + hidden_states = inputs_embeds + mask_shape = (hidden_states.shape[0], 1, past_seen_tokens, past_seen_tokens) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + else: + attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) if sp_mode in ["ring", "split_gather"]: inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index e2a367f74..5284c89f0 100644 --- a/colossalai/shardformer/policies/command.py +++ 
b/colossalai/shardformer/policies/command.py @@ -19,8 +19,6 @@ from colossalai.shardformer.layer import ( from ..modeling.command import ( CommandPipelineForwards, - get_command_flash_attention_forward, - get_command_model_forward_for_flash_attn, get_command_seq_parallel_attention_forward, get_command_seq_parallel_model_forward, get_lm_forward_with_dist_cross_entropy, @@ -95,7 +93,10 @@ class CommandPolicy(Policy): self.append_or_create_method_replacement( description={ "forward": get_command_seq_parallel_model_forward( - sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group + sp_mode=sp_mode, + sp_size=sp_size, + sp_group=sp_group, + use_flash_attention=use_flash_attention, ), }, policy=policy, @@ -103,7 +104,9 @@ class CommandPolicy(Policy): ) self.append_or_create_method_replacement( description={ - "forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), + "forward": get_command_seq_parallel_attention_forward( + sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention + ), }, policy=policy, target_key=attn_cls, @@ -120,7 +123,9 @@ class CommandPolicy(Policy): ) self.append_or_create_method_replacement( description={ - "forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), + "forward": get_command_seq_parallel_attention_forward( + sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention + ), }, policy=policy, target_key=attn_cls, @@ -131,6 +136,7 @@ class CommandPolicy(Policy): sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group, + use_flash_attention=use_flash_attention, ), }, policy=policy, @@ -234,7 +240,9 @@ class CommandPolicy(Policy): if use_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size), + "forward": get_command_seq_parallel_attention_forward( + sp_mode, sp_group, sp_size, use_flash_attention=use_flash_attention + ), }, policy=policy, target_key=attn_cls, @@ -243,7 +251,9 @@ class CommandPolicy(Policy): # replace Command model forward method self.append_or_create_method_replacement( description={ - "forward": get_command_model_forward_for_flash_attn(self.shard_config), + "forward": get_command_seq_parallel_model_forward( + sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention + ), }, policy=policy, target_key=CohereModel,