diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 0c3a55905..adf4501bb 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -2,7 +2,7 @@ name: Build on PR on: pull_request: - types: [synchronize, opened, reopened, ready_for_review, closed, edited] + types: [synchronize, opened, reopened, ready_for_review, closed] branches: - "main" - "develop" diff --git a/README.md b/README.md index e41b75c46..12d29727b 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ ## Latest News +* [2024/06] [Open-Sora Continues Open Source: Generate Any 16-Second 720p HD Video with One Click, Model Weights Ready to Use](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use) * [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference) * [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source) * [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series) @@ -32,10 +33,6 @@ * [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0) * [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora) * [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer) -* [2024/01] [Construct Refined 13B Private Model With Just $5000 USD, Upgraded Colossal-AI Llama-2 Open Source](https://hpc-ai.com/blog/colossal-llama-2-13b) -* [2023/11] [Enhanced MoE Parallelism, Open-source MoE Model Training Can Be 9 Times More Efficient](https://www.hpc-ai.tech/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient) -* [2023/09] [One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific LLM Solution](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) -* [2023/09] [70 Billion Parameter LLaMA2 Model Training Accelerated by 195%](https://www.hpc-ai.tech/blog/70b-llama2-training) * [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth) ## Table of Contents @@ -132,13 +129,13 @@ distributed training and inference in a few lines. [Open-Sora](https://github.com/hpcaitech/Open-Sora):Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models [[code]](https://github.com/hpcaitech/Open-Sora) -[[blog]](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source) -[[HuggingFace model weights]](https://huggingface.co/hpcai-tech/Open-Sora) +[[blog]](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use) +[[Model weights]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#model-weights) [[Demo]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo)
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 45fe03003..fa3c3646a 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -999,7 +999,9 @@ class HybridParallelPlugin(PipelinePluginBase): ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" if enable_sequence_parallelism: - self.sequence_parallelism_mode = sequence_parallelism_mode if sequence_parallelism_mode is not None else "1" + self.sequence_parallelism_mode = ( + sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" + ) assert ( self.sequence_parallelism_mode in SUPPORT_SP_MODE ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" @@ -1014,19 +1016,13 @@ class HybridParallelPlugin(PipelinePluginBase): self.sp_size = 1 self.dp_size = dist.get_world_size() // (tp_size * pp_size) elif self.sequence_parallelism_mode in ["all_to_all"]: - assert ( - tp_size == 1 - ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with tensor parallelism" - assert ( - pp_size == 1 - ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with pipeline parallelism" - self.sp_size = dist.get_world_size() if sp_size is None else sp_size - self.dp_size = dist.get_world_size() // (self.sp_size * pp_size) + self.sp_size = 1 if sp_size is None else sp_size + self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size) else: self.dp_size = dist.get_world_size() // (tp_size * pp_size) assert ( sp_size == 1 or sp_size is None - ), f"sp_size can only be set to a >1 number when enable_sequence_parallelism is True" + ), f"You should not set sp_size when sequence parallelism is not enabled." self.sp_size = 1 self.tp_size = tp_size @@ -1040,11 +1036,22 @@ class HybridParallelPlugin(PipelinePluginBase): self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism if dp_outside: - self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 + ( + self.dp_axis, + self.pp_axis, + self.tp_axis, + self.sp_axis, + ) = ( + 0, + 1, + 2, + 3, + ) self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) else: self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) + self.stage_manager = None self.schedule = None self.custom_policy = custom_policy diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 6197be9d1..20870a3c2 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -314,7 +314,7 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors use_safetensors (bool): whether to use safetensors to save the checkpoint. """ # Move all tensors in the state_dict to CPU before saving to avoid serialization issues - state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict) + state_dict_cpu = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, state_dict) if use_safetensors: assert is_safetensors_available(), "safetensors is not available." diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index ec4044127..0a9b5293d 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -236,7 +236,7 @@ Completion api is used for single sequence request, like answer a question or co - POST '/chat': Chat api is used for conversation-style request, which often includes dialogue participants(i.e. roles) and corresponding words. Considering the input data are very different from normal inputs, we introduce Chat-Template to match the data format in chat models. #### chat-template -Followed `transformers`, we add the chat-template argument. As chat models have been trained with very different formats for converting conversations into a single tokenizable string. Using a format that matches the training data is extremely important. This attribute(chat_template) is inclueded in HuggingFace tokenizers, containing a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace-blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example temlate bellow. Both str or file style chat template are supported. +Followed `transformers`, we add the chat-template argument. As chat models have been trained with very different formats for converting conversations into a single tokenizable string. Using a format that matches the training data is extremely important. This attribute(chat_template) is inclueded in HuggingFace tokenizers, containing a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace-blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example template bellow. Both str or file style chat template are supported. ### Usage #### Args for customizing your server The configuration for api server contains both serving interface and engine backend. diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 9cf9a65e6..c73ee9df4 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -10,6 +10,7 @@ import torch from transformers.generation import GenerationConfig from colossalai.inference.flash_decoding_utils import FDIntermTensors +from colossalai.inference.utils import can_use_flash_attn2 GibiByte = 1024**3 @@ -169,7 +170,8 @@ class InferenceConfig(RPC_PARAM): no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences. repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0. ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token. - n_spec_tokens (int): The maximum number of speculating tokens, defaults to None. + use_spec_dec (bool): Indicate whether to use speculative decoding, defaults to False. + max_n_spec_tokens (int): The maximum number of speculating tokens, defaults to None. glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False. block_size (int): The number of blocks in a logical block, defaults to 16. tp_size (int): Tensor parallel size, defaults to 1. @@ -214,6 +216,7 @@ class InferenceConfig(RPC_PARAM): ignore_eos: bool = False # speculative decoding configs + use_spec_dec: bool = False max_n_spec_tokens: int = 5 glimpse_large_kv: bool = False @@ -311,6 +314,16 @@ class InferenceConfig(RPC_PARAM): return GenerationConfig.from_dict(meta_config) + def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig": + use_flash_attn = can_use_flash_attn2(self.dtype) + model_inference_config = ModelShardInferenceConfig( + dtype=self.dtype, + use_cuda_kernel=self.use_cuda_kernel, + use_spec_dec=self.use_spec_dec, + use_flash_attn=use_flash_attn, + ) + return model_inference_config + def to_rpc_param(self) -> dict: kwargs = { "dtype": str(self.dtype).split(".")[-1], @@ -362,3 +375,21 @@ class InferenceConfig(RPC_PARAM): # Set the attributes from the parsed arguments. inference_config = cls(**inference_config_args) return inference_config + + +@dataclass +class ModelShardInferenceConfig: + """ + Configurations used during init of module for inference modeling. + + Args: + dtype (torch.dtype): The data type for weights and activations. + use_cuda_kernel (bool): Whether to use cuda kernel, faster but lose some precision occasionally + use_spec_dec (bool): Indicate whether to use speculative decoding. + use_flash_attn (bool): Indicate whether to use flash attention. + """ + + dtype: torch.dtype = None + use_cuda_kernel: bool = False + use_spec_dec: bool = False + use_flash_attn: bool = False diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 1b6e62553..a1b54fa1c 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM from colossalai.accelerator import get_accelerator from colossalai.cluster import ProcessGroupMesh from colossalai.inference.batch_bucket import BatchBucket -from colossalai.inference.config import InferenceConfig, InputMetaData +from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig from colossalai.inference.graph_runner import CUDAGraphRunner from colossalai.inference.modeling.policy import model_policy_map from colossalai.inference.sampler import search_tokens @@ -72,8 +72,9 @@ class InferenceEngine: self.verbose = verbose self.logger = get_dist_logger(__name__) + self.model_shard_infer_config = inference_config.to_model_shard_inference_config() - self.init_model(model_or_path, model_policy) + self.init_model(model_or_path, model_policy, self.model_shard_infer_config) self.generation_config = inference_config.to_generation_config(self.model_config) self.generation_config_dict = self.generation_config.to_dict() @@ -97,7 +98,8 @@ class InferenceEngine: self.capture_model(self.k_cache, self.v_cache) # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec` - self.use_spec_dec = False + self.use_spec_dec = self.inference_config.use_spec_dec + self.drafter_model = None self.drafter = None self.use_glide = False @@ -105,13 +107,20 @@ class InferenceEngine: self._verify_args() - def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None): + def init_model( + self, + model_or_path: Union[nn.Module, str], + model_policy: Union[Policy, Type[Policy]] = None, + model_shard_infer_config: ModelShardInferenceConfig = None, + ): """ Shard model or/and Load weight Args: model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. - model_policy (Policy): the policy to replace the model + model_policy (Policy): the policy to replace the model. + model_inference_config: the configuration for modeling initialization when inference. + model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference. """ if isinstance(model_or_path, str): @@ -124,6 +133,7 @@ class InferenceEngine: # the model load process in the future. model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True) else: + # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate raise ValueError(f"Model {arch} is not supported.") except Exception as e: @@ -167,6 +177,7 @@ class InferenceEngine: self.model = self._shardformer( model, model_policy, + model_shard_infer_config, None, tp_group=tp_group, ) @@ -187,7 +198,7 @@ class InferenceEngine: # assert if_has_index_file, "the model path is invalid" # cpt_io.load_model(self.model, model_index_file) - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + free_gpu_memory, _ = torch.cuda.mem_get_info() peak_memory = init_gpu_memory - free_gpu_memory if self.verbose: self.logger.info( @@ -287,6 +298,7 @@ class InferenceEngine: self, model: nn.Module, model_policy: Policy, + model_shard_infer_config: ModelShardInferenceConfig = None, stage_manager: PipelineStageManager = None, tp_group: ProcessGroupMesh = None, ) -> nn.Module: @@ -312,6 +324,7 @@ class InferenceEngine: enable_flash_attention=False, enable_jit_fused=False, enable_sequence_parallelism=False, + extra_kwargs={"model_shard_infer_config": model_shard_infer_config}, ) shardformer = ShardFormer(shard_config=shardconfig) shard_model, _ = shardformer.optimize(model, model_policy) @@ -348,6 +361,7 @@ class InferenceEngine: engine.clear_spec_dec() ``` """ + if drafter_model is None and self.drafter is None: raise ValueError("Drafter not initialized. Please provide a Drafter Model") if n_spec_tokens is not None: @@ -452,6 +466,7 @@ class InferenceEngine: self.k_cache[-1], # use kv cahces of the last layer self.v_cache[-1], batch.get_sequence_lengths(), + n_spec_tokens=self.n_spec_tokens, ) drafter_out = self.drafter.speculate( @@ -517,19 +532,19 @@ class InferenceEngine: prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, return_token_ids: bool = False, generation_config: Optional[GenerationConfig] = None, - ) -> List[str]: + ) -> Union[List[str], Tuple[List[str], List[List[int]]]]: """ Executing the inference step. Args: - prompts (Union[List[str], optional): Input prompts. Defaults to None. - prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None. request_ids (List[int], optional): The request ID. Defaults to None. - return_token_ids (bool): Whether to return output token ids. Defaults to False. - generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None. + prompts (Union[List[str], optional): Input prompts. Defaults to None. + prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None. + return_token_ids (bool, optional): Whether to return output token ids. Defaults to False. + generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None. Returns: - List[str]: Inference result returned by one generation. + Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation. """ gen_config_dict = generation_config.to_dict() if generation_config is not None else {} diff --git a/colossalai/inference/modeling/backends/__init__.py b/colossalai/inference/modeling/backends/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/inference/modeling/backends/attention_backend.py b/colossalai/inference/modeling/backends/attention_backend.py new file mode 100644 index 000000000..ab586f510 --- /dev/null +++ b/colossalai/inference/modeling/backends/attention_backend.py @@ -0,0 +1,170 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass + +import torch + +from colossalai.inference.config import ModelShardInferenceConfig +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.kernel.triton import context_attention_unpadded, flash_decoding_attention + + +@dataclass +class AttentionMetaData: + query_states: torch.Tensor + key_states: torch.Tensor + value_states: torch.Tensor + k_cache: torch.Tensor + v_cache: torch.Tensor + block_tables: torch.Tensor + block_size: int + kv_seq_len: int = None + sequence_lengths: torch.Tensor = None + cu_seqlens: torch.Tensor = None + sm_scale: int = None + alibi_slopes: torch.Tensor = None + output_tensor: torch.Tensor = None + use_spec_dec: bool = False + use_alibi_attn: bool = False + + +class AttentionBackend(ABC): + @abstractmethod + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + raise NotImplementedError + + @abstractmethod + def decode(self, attn_metadatas: AttentionMetaData, **kwargs): + raise NotImplementedError + + +class CudaAttentionBackend(AttentionBackend): + """ + Attention backend when use_cuda_kernel is True but flash-attn not found. If flash-attn is not found, + it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding. + """ + + def __init__(self, use_flash_attn: bool = False): + super().__init__() + self.inference_ops = InferenceOpsLoader().load() + self.use_flash_attn = use_flash_attn + + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + if self.use_flash_attn: + token_nums = kwargs.get("token_nums", -1) + + from flash_attn import flash_attn_varlen_func + + attn_output = flash_attn_varlen_func( + attn_metadata.query_states, + attn_metadata.key_states, + attn_metadata.value_states, + cu_seqlens_q=attn_metadata.cu_seqlens, + cu_seqlens_k=attn_metadata.cu_seqlens, + max_seqlen_q=attn_metadata.kv_seq_len, + max_seqlen_k=attn_metadata.kv_seq_len, + dropout_p=0.0, + softmax_scale=attn_metadata.sm_scale, + causal=True, + alibi_slopes=attn_metadata.alibi_slopes, + ) + attn_output = attn_output.view(token_nums, -1) + else: + attn_output = context_attention_unpadded( + q=attn_metadata.query_states, + k=attn_metadata.key_states, + v=attn_metadata.value_states, + k_cache=attn_metadata.k_cache, + v_cache=attn_metadata.v_cache, + context_lengths=attn_metadata.sequence_lengths, + block_tables=attn_metadata.block_tables, + block_size=attn_metadata.block_size, + output=attn_metadata.output_tensor, + alibi_slopes=attn_metadata.alibi_slopes, + max_seq_len=attn_metadata.kv_seq_len, + sm_scale=attn_metadata.sm_scale, + use_new_kcache_layout=True, # use new k-cache layout + ) + return attn_output + + def decode(self, attn_metadata: AttentionMetaData, **kwargs): + fd_inter_tensor = kwargs.get("fd_inter_tensor", None) + output_tensor = attn_metadata.output_tensor + self.inference_ops.flash_decoding_attention( + output_tensor, + attn_metadata.query_states, + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.sequence_lengths, + attn_metadata.block_tables, + attn_metadata.block_size, + attn_metadata.kv_seq_len, + fd_inter_tensor.mid_output, + fd_inter_tensor.exp_sums, + fd_inter_tensor.max_logits, + attn_metadata.alibi_slopes, + attn_metadata.sm_scale, + ) + return output_tensor + + +class TritonAttentionBackend(AttentionBackend): + """ + Attention backend when use_cuda_kernel is False. It uses pure Triton ops for prefilling and decoding. + """ + + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + return context_attention_unpadded( + q=attn_metadata.query_states, + k=attn_metadata.key_states, + v=attn_metadata.value_states, + k_cache=attn_metadata.k_cache, + v_cache=attn_metadata.v_cache, + context_lengths=attn_metadata.sequence_lengths, + block_tables=attn_metadata.block_tables, + block_size=attn_metadata.block_size, + output=attn_metadata.output_tensor, + alibi_slopes=attn_metadata.alibi_slopes, + max_seq_len=attn_metadata.kv_seq_len, + sm_scale=attn_metadata.sm_scale, + ) + + def decode(self, attn_metadata: AttentionMetaData, **kwargs): + fd_inter_tensor = kwargs.get("fd_inter_tensor", None) + return flash_decoding_attention( + q=attn_metadata.query_states, + k_cache=attn_metadata.k_cache, + v_cache=attn_metadata.v_cache, + kv_seq_len=attn_metadata.sequence_lengths, + block_tables=attn_metadata.block_tables, + block_size=attn_metadata.block_size, + max_seq_len_in_batch=attn_metadata.kv_seq_len, + output=attn_metadata.output_tensor, + mid_output=fd_inter_tensor.mid_output, + mid_output_lse=fd_inter_tensor.mid_output_lse, + alibi_slopes=attn_metadata.alibi_slopes, + sm_scale=attn_metadata.sm_scale, + kv_group_num=kwargs.get("num_key_value_groups", 1), + q_len=kwargs.get("q_len", 1), + ) + + +def get_attention_backend( + model_shard_infer_config: ModelShardInferenceConfig, +) -> AttentionBackend: + """ + Get the attention backend based on the inference configurations. The modeling will use CUDA-kernel-based backend + for attention module calculation only when: + 1. using CUDA kernel (use_cuda_kernel=True) + 2. can use flash attention (flash-attn installed and dtype is fp16 or bf16) + 3. not using speculative decoding (currently cuda kernel not support speculative decoding) + Otherwise, use Triton attention backend. If found flash-attn not installed while `use_cuda_kernel` is True, + the Triton backend will use a new k cache layout for Triton kernels. + """ + # Currently only triton kernels support speculative decoding + if model_shard_infer_config.use_spec_dec: + return TritonAttentionBackend() + + if model_shard_infer_config.use_cuda_kernel: + return CudaAttentionBackend(model_shard_infer_config.use_flash_attn) + + return TritonAttentionBackend() diff --git a/colossalai/inference/modeling/backends/pre_attention_backend.py b/colossalai/inference/modeling/backends/pre_attention_backend.py new file mode 100644 index 000000000..77804429d --- /dev/null +++ b/colossalai/inference/modeling/backends/pre_attention_backend.py @@ -0,0 +1,146 @@ +from abc import ABC, abstractmethod + +from colossalai.inference.config import ModelShardInferenceConfig +from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.kernel.triton import copy_k_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding + + +class PreAttentionBackend(ABC): + @abstractmethod + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + raise NotImplementedError + + @abstractmethod + def decode(self, attn_metadata: AttentionMetaData, **kwargs): + raise NotImplementedError + + +class CudaPreAttentionBackend(PreAttentionBackend): + """ + CudaPreAttentionBackend handles KV cache initialization and positional encoding for CudaAttentionBackend. + """ + + def __init__(self, use_flash_attn: bool): + super().__init__() + self.inference_ops = InferenceOpsLoader().load() + self.use_flash_attn = use_flash_attn + + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + if self.use_flash_attn: + if not attn_metadata.use_alibi_attn: + self.inference_ops.rotary_embedding( + attn_metadata.query_states, + attn_metadata.key_states, + kwargs.get("cos", None), + kwargs.get("sin", None), + kwargs.get("high_precision", False), + ) + self.inference_ops.context_kv_cache_memcpy( + attn_metadata.key_states, + attn_metadata.value_states, + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.sequence_lengths, + attn_metadata.cu_seqlens, + attn_metadata.block_tables, + attn_metadata.kv_seq_len, + ) + elif not attn_metadata.use_alibi_attn: + rotary_embedding( + attn_metadata.query_states, + attn_metadata.key_states, + kwargs.get("cos", None), + kwargs.get("sin", None), + ) + + def decode(self, attn_metadata: AttentionMetaData, **kwargs): + if not attn_metadata.use_alibi_attn: + self.inference_ops.rotary_embedding_and_cache_copy( + attn_metadata.query_states, + attn_metadata.key_states, + attn_metadata.value_states, + kwargs.get("cos", None), + kwargs.get("sin", None), + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.sequence_lengths, + attn_metadata.block_tables, + kwargs.get("high_precision", None), + ) + else: + self.inference_ops.decode_kv_cache_memcpy( + attn_metadata.key_states, + attn_metadata.value_states, + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.sequence_lengths, + attn_metadata.block_tables, + ) + + +class TritonPreAttentionBackend(PreAttentionBackend): + """ + TritonPreAttentionBackend handles KV cache initialization and positional encoding for TritonAttentionBackend. + """ + + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + if not attn_metadata.use_alibi_attn: + rotary_embedding( + attn_metadata.query_states, + attn_metadata.key_states, + kwargs.get("cos", None), + kwargs.get("sin", None), + ) + + def decode(self, attn_metadata: AttentionMetaData, **kwargs): + if not attn_metadata.use_spec_dec and not attn_metadata.use_alibi_attn: + decoding_fused_rotary_embedding( + attn_metadata.query_states, + attn_metadata.key_states, + attn_metadata.value_states, + kwargs.get("cos", None), + kwargs.get("sin", None), + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.block_tables, + attn_metadata.sequence_lengths, + ) + else: # else if using speculative decoding + if not attn_metadata.use_alibi_attn: + rotary_embedding( + attn_metadata.query_states, + attn_metadata.key_states, + kwargs.get("cos", None), + kwargs.get("sin", None), + ) + copy_k_to_blocked_cache( + attn_metadata.key_states, + attn_metadata.k_cache, + kv_lengths=attn_metadata.sequence_lengths, + block_tables=attn_metadata.block_tables, + n=kwargs.get("q_len", 1), + ) + copy_k_to_blocked_cache( + attn_metadata.value_states, + attn_metadata.v_cache, + kv_lengths=attn_metadata.sequence_lengths, + block_tables=attn_metadata.block_tables, + n=kwargs.get("q_len", 1), + ) + + +def get_pre_attention_backend( + model_shard_infer_config: ModelShardInferenceConfig, +) -> PreAttentionBackend: + """ + Get the backend for pre-attention computations, including potisional encoding like + RoPE and KV cache initialization. It adopt the same selection logic as attention_backend/get_attention_backend. + """ + if model_shard_infer_config.use_spec_dec: + return TritonPreAttentionBackend() + + if model_shard_infer_config.use_cuda_kernel: + return CudaPreAttentionBackend(model_shard_infer_config.use_flash_attn) + + return TritonPreAttentionBackend() diff --git a/colossalai/inference/modeling/layers/baichuan_tp_linear.py b/colossalai/inference/modeling/layers/baichuan_tp_linear.py index e050dd71c..50806a14b 100644 --- a/colossalai/inference/modeling/layers/baichuan_tp_linear.py +++ b/colossalai/inference/modeling/layers/baichuan_tp_linear.py @@ -15,25 +15,10 @@ class BaichuanLMHeadLinear1D_Col(Linear1D_Col): module.in_features = module.weight.size(1) module.out_features = module.weight.size(0) module.bias = None - module.weight.data = nn.functional.normalize(module.weight) - - return Linear1D_Col.from_native_module( - module, - process_group, - *args, - **kwargs, - ) - - -class BaichuanWpackLinear1D_Col(Linear1D_Col): - @staticmethod - def from_native_module( - module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs - ) -> ParallelModule: - in_features = module.in_features * 3 - out_features = module.out_features // 3 - module.weight.data = module.weight.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features) - module.bias = None + module.weight.data = nn.functional.normalize( + module.weight + ) # TODO(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight. + # So we should rewrite our own load_from_state_dict of `BaichuanLMHeadLinear1D_Col` to fix this potential issue. return Linear1D_Col.from_native_module( module, diff --git a/colossalai/inference/modeling/models/glide_llama.py b/colossalai/inference/modeling/models/glide_llama.py index 7b25f3e74..0ee78a303 100644 --- a/colossalai/inference/modeling/models/glide_llama.py +++ b/colossalai/inference/modeling/models/glide_llama.py @@ -6,11 +6,7 @@ from typing import List, Optional, Tuple, Union import torch import torch.nn as nn -from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) +from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.llama.modeling_llama import ( LlamaAttention, @@ -137,6 +133,7 @@ def glide_llama_model_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -147,57 +144,43 @@ def glide_llama_model_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = inputs_embeds.shape[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - past_key_values_length = 0 - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) - position_ids = position_ids.unsqueeze(0) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.") + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + past_seen_tokens = 0 + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + # embed positions hidden_states = inputs_embeds # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -212,6 +195,7 @@ def glide_llama_model_forward( past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -230,7 +214,9 @@ def glide_llama_model_forward( next_cache = None if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -333,7 +319,8 @@ class LlamaCrossAttention(nn.Module): query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2) # for RoPE - cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len + 32) + position_ids = position_ids + glide_input.n_spec_tokens + cos, sin = self.rotary_emb(query_states, position_ids) query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids) query_states = query_states.transpose(1, 2) query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim) diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py index b50e73d6f..3bab671c4 100644 --- a/colossalai/inference/modeling/models/nopadding_baichuan.py +++ b/colossalai/inference/modeling/models/nopadding_baichuan.py @@ -1,68 +1,27 @@ # This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py -import itertools -import math from typing import List, Optional, Tuple, Union import torch import torch.nn as nn from torch.distributed import ProcessGroup +from colossalai.accelerator import get_accelerator +from colossalai.inference.config import ModelShardInferenceConfig from colossalai.inference.flash_decoding_utils import FDIntermTensors +from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend +from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP +from colossalai.inference.utils import get_alibi_slopes from colossalai.kernel.kernel_loader import InferenceOpsLoader -from colossalai.kernel.triton import ( - context_attention_unpadded, - copy_k_to_blocked_cache, - decoding_fused_rotary_embedding, - flash_decoding_attention, - rms_layernorm, - rotary_embedding, -) +from colossalai.kernel.triton import rms_layernorm from colossalai.logging import get_dist_logger from colossalai.shardformer.layer.parallel_module import ParallelModule -from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor - -logger = get_dist_logger(__name__) - -try: - from flash_attn import flash_attn_varlen_func - - use_flash_attn2 = True -except ImportError: - use_flash_attn2 = False - logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") - -logger = get_dist_logger(__name__) - -try: - from flash_attn import flash_attn_varlen_func - - use_flash_attn2 = True -except ImportError: - use_flash_attn2 = False - logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") +from colossalai.tensor.d_tensor import is_distributed_tensor inference_ops = InferenceOpsLoader().load() - logger = get_dist_logger(__name__) -# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57 -def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor: - closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) - base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device) - powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device) - slopes = torch.pow(base, powers) - if closest_power_of_2 != num_heads: - extra_base = torch.tensor( - 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device - ) - num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) - extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device) - slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) - return slopes - - def baichuan_rmsnorm_forward( self, hidden_states: torch.Tensor, @@ -96,23 +55,19 @@ class NopadBaichuanAttention(ParallelModule): def __init__( self, config, - attn_qproj_w: torch.Tensor = None, - attn_kproj_w: torch.Tensor = None, - attn_vproj_w: torch.Tensor = None, + W_pack: ParallelModule = None, attn_oproj: ParallelModule = None, num_heads: int = None, hidden_size: int = None, + model_shard_infer_config: ModelShardInferenceConfig = None, process_group: ProcessGroup = None, - helper_layout: Layout = None, ): """This layer will replace the BaichuanAttention. Args: config (BaichuanConfig): Holding the Baichuan model config. - attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None. - attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None. - attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None. - attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj weight. Defaults to None. + W_pack (ParallelModule, optional): The packed weight. Defaults to None. + attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj. Defaults to None. """ ParallelModule.__init__(self) self.o_proj = attn_oproj @@ -122,10 +77,10 @@ class NopadBaichuanAttention(ParallelModule): self.hidden_size = hidden_size self.head_dim = self.hidden_size // self.num_heads self.process_group = process_group - qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)] - self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0)) - - self.helper_layout = helper_layout + self.W_pack = W_pack + self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel + self.attention_backend = get_attention_backend(model_shard_infer_config) + self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config) self.alibi_slopes = None self.use_alibi_attn = False @@ -133,9 +88,9 @@ class NopadBaichuanAttention(ParallelModule): if config.hidden_size == 5120: slopes_start = self.process_group.rank() * num_heads self.use_alibi_attn = True - self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[ - slopes_start : slopes_start + num_heads - ].contiguous() + self.alibi_slopes = get_alibi_slopes( + config.num_attention_heads, device=get_accelerator().get_current_device() + )[slopes_start : slopes_start + num_heads].contiguous() self.alibi_slopes = nn.Parameter(self.alibi_slopes) @staticmethod @@ -149,76 +104,22 @@ class NopadBaichuanAttention(ParallelModule): """ config = module.config - q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((module.hidden_size, 3, -1)).transpose(0, 1) - - attn_qproj_w = q_proj_w - attn_kproj_w = k_proj_w - attn_vproj_w = v_proj_w + W_pack = module.W_pack attn_oproj = module.o_proj - - helper_layout = ( - module.W_pack.weight.dist_layout - ) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict) + model_shard_infer_config = kwargs.get("model_shard_infer_config", None) attn_layer = NopadBaichuanAttention( config=config, - attn_qproj_w=attn_qproj_w, - attn_kproj_w=attn_kproj_w, - attn_vproj_w=attn_vproj_w, + W_pack=W_pack, attn_oproj=attn_oproj, + model_shard_infer_config=model_shard_infer_config, num_heads=module.num_heads, hidden_size=module.hidden_size, process_group=process_group, - helper_layout=helper_layout, ) return attn_layer - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - for hook in self._load_state_dict_pre_hooks.values(): - hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - - persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} - local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) - local_state = {k: v for k, v in local_name_params if v is not None} - - key = "qkv_weight" - qkv_w = state_dict[prefix + "W_pack.weight"] - - in_features = qkv_w.size(1) - out_features = qkv_w.size(0) // 3 - - qkv_w.data = qkv_w.view((3, out_features, -1)).transpose(0, 1).reshape(out_features, in_features * 3) - - device_mesh = self.helper_layout.device_mesh - sharding_spec = self.helper_layout.sharding_spec - qkv_w = distribute_tensor(qkv_w, device_mesh, sharding_spec) - - qkv_w = qkv_w.transpose(0, 1).reshape(3, in_features, -1) - input_param = nn.Parameter( - qkv_w - ) # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param) - - param = local_state[key] - - try: - with torch.no_grad(): - param.copy_(input_param) - except Exception as ex: - error_msgs.append( - 'While copying the parameter named "{}", ' - "whose dimensions in the model are {} and " - "whose dimensions in the checkpoint are {}, " - "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) - ) - - strict = False # to avoid unexpected_keys - super()._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ) - def forward( self, hidden_states: torch.Tensor, @@ -234,7 +135,6 @@ class NopadBaichuanAttention(ParallelModule): kv_seq_len: int = 0, output_tensor: torch.Tensor = None, sm_scale: int = None, - use_cuda_kernel: bool = True, cu_seqlens: torch.Tensor = None, high_precision: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -253,144 +153,66 @@ class NopadBaichuanAttention(ParallelModule): kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. sm_scale (int, optional): Used for flash attention. Defaults to None. - use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True. cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length. high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. """ - token_nums = hidden_states.size(0) - # fused qkv - hidden_states = hidden_states.expand(3, -1, -1) - query_states, key_states, value_states = ( - torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0) - ) + + proj = self.W_pack(hidden_states) + proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) + query_states = proj[0].view(token_nums, self.num_heads, self.head_dim) + key_states = proj[1].view(token_nums, self.num_heads, self.head_dim) + value_states = proj[2].view(token_nums, self.num_heads, self.head_dim) block_size = k_cache.size(-2) - if is_prompts: - if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2: - # flash attn 2 currently only supports FP16/BF16. - if not self.use_alibi_attn: - inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision) - inference_ops.context_kv_cache_memcpy( - key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len - ) - attn_output = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=kv_seq_len, - max_seqlen_k=kv_seq_len, - dropout_p=0.0, - softmax_scale=sm_scale, - causal=True, - alibi_slopes=self.alibi_slopes, - ) - attn_output = attn_output.view(token_nums, -1) - else: - if not self.use_alibi_attn: - rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) - attn_output = context_attention_unpadded( - q=query_states, - k=key_states, - v=value_states, - k_cache=k_cache, - v_cache=v_cache, - context_lengths=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - output=output_tensor, - alibi_slopes=self.alibi_slopes, - max_seq_len=kv_seq_len, - sm_scale=sm_scale, - use_new_kcache_layout=use_cuda_kernel, - ) - else: - q_len = tokens_to_verify + 1 if is_verifier else 1 + attn_metadata = AttentionMetaData( + query_states=query_states, + key_states=key_states, + value_states=value_states, + k_cache=k_cache, + v_cache=v_cache, + block_tables=block_tables, + block_size=block_size, + kv_seq_len=kv_seq_len, + sequence_lengths=sequence_lengths, + sm_scale=sm_scale, + alibi_slopes=self.alibi_slopes, + cu_seqlens=cu_seqlens, + output_tensor=output_tensor, + use_spec_dec=is_verifier, + use_alibi_attn=self.use_alibi_attn, + ) - if use_cuda_kernel: - if not self.use_alibi_attn: - inference_ops.rotary_embedding_and_cache_copy( - query_states, - key_states, - value_states, - cos_sin[0], - cos_sin[1], - k_cache, - v_cache, - sequence_lengths, - block_tables, - high_precision, - ) - else: - inference_ops.decode_kv_cache_memcpy( - key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables - ) - inference_ops.flash_decoding_attention( - output_tensor, - query_states, - k_cache, - v_cache, - sequence_lengths, - block_tables, - block_size, - kv_seq_len, - fd_inter_tensor.mid_output, - fd_inter_tensor.exp_sums, - fd_inter_tensor.max_logits, - self.alibi_slopes, - sm_scale, - ) - attn_output = output_tensor - else: - if not is_verifier and not self.use_alibi_attn: - decoding_fused_rotary_embedding( - query_states, - key_states, - value_states, - cos_sin[0], - cos_sin[1], - k_cache, - v_cache, - block_tables, - sequence_lengths, - ) - else: - if not self.use_alibi_attn: - rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) - copy_k_to_blocked_cache( - key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len - ) - copy_k_to_blocked_cache( - value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len - ) + if is_prompts: # prefilling stage + self.pre_attention_backend.prefill( + attn_metadata, + cos=cos_sin[0], + sin=cos_sin[1], + high_precision=high_precision, + ) + attn_output = self.attention_backend.prefill( + attn_metadata, + token_nums=token_nums, + ) + else: # decoding stage + q_len = tokens_to_verify + 1 if is_verifier else 1 - attn_output = flash_decoding_attention( - q=query_states, - k_cache=k_cache, - v_cache=v_cache, - kv_seq_len=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - max_seq_len_in_batch=kv_seq_len, - output=output_tensor, - mid_output=fd_inter_tensor.mid_output, - mid_output_lse=fd_inter_tensor.mid_output_lse, - alibi_slopes=self.alibi_slopes, - sm_scale=sm_scale, - q_len=q_len, - ) + self.pre_attention_backend.decode( + attn_metadata, + q_len=q_len, + ) + attn_output = self.attention_backend.decode( + attn_metadata, + fd_inter_tensor=fd_inter_tensor, + q_len=q_len, + ) attn_output = attn_output.view(-1, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output - def extra_repr(self) -> str: - return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False" - # NOTE This will cause difference as out length increases. class NopadBaichuanMLP(NopadLlamaMLP): diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index f6f160eb7..445ec59ce 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -16,18 +16,13 @@ from transformers.models.llama.modeling_llama import ( LlamaRMSNorm, ) -from colossalai.inference.config import InputMetaData +from colossalai.inference.config import InputMetaData, ModelShardInferenceConfig from colossalai.inference.flash_decoding_utils import FDIntermTensors +from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend +from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend +from colossalai.inference.utils import can_use_flash_attn2 from colossalai.kernel.kernel_loader import InferenceOpsLoader -from colossalai.kernel.triton import ( - context_attention_unpadded, - copy_k_to_blocked_cache, - decoding_fused_rotary_embedding, - flash_decoding_attention, - get_xine_cache, - rms_layernorm, - rotary_embedding, -) +from colossalai.kernel.triton import get_xine_cache, rms_layernorm from colossalai.logging import get_dist_logger from colossalai.shardformer.layer.parallel_module import ParallelModule from colossalai.tensor.d_tensor import distribute_tensor, is_distributed_tensor @@ -36,14 +31,6 @@ inference_ops = InferenceOpsLoader().load() logger = get_dist_logger(__name__) -try: - from flash_attn import flash_attn_varlen_func - - use_flash_attn2 = True -except ImportError: - use_flash_attn2 = False - logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") - def llama_causal_lm_forward( self: LlamaForCausalLM, @@ -126,8 +113,8 @@ def llama_model_forward( cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes]) elif use_cuda_kernel: - if inputmetadata.dtype != torch.float32 and use_flash_attn2: - cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0)) + if can_use_flash_attn2(inputmetadata.dtype): + cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0)) hidden_dim = self._cos_cached.size(-1) total_length = hidden_states.size(0) @@ -238,7 +225,6 @@ def llama_decoder_layer_forward( kv_seq_len=kv_seq_len, output_tensor=output_tensor, sm_scale=sm_scale, - use_cuda_kernel=use_cuda_kernel, cu_seqlens=cu_seqlens, high_precision=high_precision, ) @@ -279,7 +265,7 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule): mlp_dproj: ParallelModule = None, process_group: ProcessGroup = None, ): - """A Unified Layer for + """Replacement of LlamaMLP layer. Args: config (LlamaConfig): Holding the Llama model config. @@ -402,6 +388,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): attn_vproj_w: torch.Tensor = None, attn_oproj: ParallelModule = None, process_group: ProcessGroup = None, + model_shard_infer_config: ModelShardInferenceConfig = None, num_heads: int = None, hidden_size: int = None, num_key_value_heads: int = None, @@ -433,6 +420,9 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): self.rope_theta = config.rope_theta self.is_causal = True + self.attention_backend = get_attention_backend(model_shard_infer_config) + self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config) + if self.num_heads == self.num_key_value_heads: qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)] self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0)) @@ -462,6 +452,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): attn_vproj_w = module.v_proj.weight assert is_distributed_tensor(attn_qproj_w), "attn_qproj_w must be dist tensor" attn_oproj = module.o_proj + model_shard_infer_config = kwargs.get("model_shard_infer_config", None) attn_layer = NopadLlamaAttention( config=config, @@ -471,6 +462,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): attn_vproj_w=attn_vproj_w, attn_oproj=attn_oproj, process_group=process_group, + model_shard_infer_config=model_shard_infer_config, num_heads=module.num_heads, hidden_size=module.hidden_size, num_key_value_heads=module.num_key_value_heads, @@ -533,111 +525,50 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): block_size = k_cache.size(-2) - if is_prompts: - if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2: - # flash attn 2 currently only supports FP16/BF16. - inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision) - inference_ops.context_kv_cache_memcpy( - key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len - ) + attn_metadata = AttentionMetaData( + query_states=query_states, + key_states=key_states, + value_states=value_states, + k_cache=k_cache, + v_cache=v_cache, + block_tables=block_tables, + block_size=block_size, + kv_seq_len=kv_seq_len, + sequence_lengths=sequence_lengths, + sm_scale=sm_scale, + alibi_slopes=None, + cu_seqlens=cu_seqlens, + output_tensor=output_tensor, + use_spec_dec=is_verifier, + use_alibi_attn=False, + ) - attn_output = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=kv_seq_len, - max_seqlen_k=kv_seq_len, - dropout_p=0.0, - softmax_scale=sm_scale, - causal=True, - ) - attn_output = attn_output.view(token_nums, -1) - else: - rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) - attn_output = context_attention_unpadded( - q=query_states, - k=key_states, - v=value_states, - k_cache=k_cache, - v_cache=v_cache, - context_lengths=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - output=output_tensor, - max_seq_len=kv_seq_len, - sm_scale=sm_scale, - use_new_kcache_layout=use_cuda_kernel, - ) - else: + if is_prompts: # prefilling stage + self.pre_attention_backend.prefill( + attn_metadata, + cos=cos_sin[0], + sin=cos_sin[1], + high_precision=high_precision, + ) + attn_output = self.attention_backend.prefill( + attn_metadata, + token_nums=token_nums, + ) + else: # decoding stage q_len = tokens_to_verify + 1 if is_verifier else 1 - if use_cuda_kernel: - inference_ops.rotary_embedding_and_cache_copy( - query_states, - key_states, - value_states, - cos_sin[0], - cos_sin[1], - k_cache, - v_cache, - sequence_lengths, - block_tables, - high_precision, - ) - inference_ops.flash_decoding_attention( - output_tensor, - query_states, - k_cache, - v_cache, - sequence_lengths, - block_tables, - block_size, - kv_seq_len, - fd_inter_tensor.mid_output, - fd_inter_tensor.exp_sums, - fd_inter_tensor.max_logits, - None, - sm_scale, - ) - attn_output = output_tensor - else: - if is_verifier: - rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) - copy_k_to_blocked_cache( - key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len - ) - copy_k_to_blocked_cache( - value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len - ) - else: - decoding_fused_rotary_embedding( - query_states, - key_states, - value_states, - cos_sin[0], - cos_sin[1], - k_cache, - v_cache, - block_tables, - sequence_lengths, - ) - attn_output = flash_decoding_attention( - q=query_states, - k_cache=k_cache, - v_cache=v_cache, - kv_seq_len=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - max_seq_len_in_batch=kv_seq_len, - output=output_tensor, - mid_output=fd_inter_tensor.mid_output, - mid_output_lse=fd_inter_tensor.mid_output_lse, - sm_scale=sm_scale, - kv_group_num=self.num_key_value_groups, - q_len=q_len, - ) + self.pre_attention_backend.decode( + attn_metadata, + cos=cos_sin[0], + sin=cos_sin[1], + q_len=q_len, + ) + attn_output = self.attention_backend.decode( + attn_metadata, + fd_inter_tensor=fd_inter_tensor, + num_key_value_groups=self.num_key_value_groups, + q_len=q_len, + ) attn_output = attn_output.view(-1, self.hidden_size) attn_output = self.o_proj(attn_output) diff --git a/colossalai/inference/modeling/policy/nopadding_baichuan.py b/colossalai/inference/modeling/policy/nopadding_baichuan.py index 78268d6e7..37b5062e8 100644 --- a/colossalai/inference/modeling/policy/nopadding_baichuan.py +++ b/colossalai/inference/modeling/policy/nopadding_baichuan.py @@ -1,8 +1,5 @@ from colossalai.inference.config import RPC_PARAM -from colossalai.inference.modeling.layers.baichuan_tp_linear import ( - BaichuanLMHeadLinear1D_Col, - BaichuanWpackLinear1D_Col, -) +from colossalai.inference.modeling.layers.baichuan_tp_linear import BaichuanLMHeadLinear1D_Col from colossalai.inference.modeling.models.nopadding_baichuan import ( NopadBaichuanAttention, NopadBaichuanMLP, @@ -14,7 +11,7 @@ from colossalai.inference.modeling.models.nopadding_llama import ( llama_model_forward, ) from colossalai.inference.utils import init_to_get_rotary -from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer import FusedLinear1D_Col, Linear1D_Col, Linear1D_Row from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy @@ -60,8 +57,7 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM): target_module=NopadBaichuanMLP, ), SubModuleReplacementDescription( - suffix="self_attn.W_pack", - target_module=BaichuanWpackLinear1D_Col, + suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3} ), SubModuleReplacementDescription( suffix="self_attn.o_proj", @@ -70,6 +66,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM): SubModuleReplacementDescription( suffix="self_attn", target_module=NopadBaichuanAttention, + kwargs={ + "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"], + }, ), ], ) diff --git a/colossalai/inference/modeling/policy/nopadding_llama.py b/colossalai/inference/modeling/policy/nopadding_llama.py index 24cf7c740..0b6797560 100644 --- a/colossalai/inference/modeling/policy/nopadding_llama.py +++ b/colossalai/inference/modeling/policy/nopadding_llama.py @@ -72,6 +72,9 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM): SubModuleReplacementDescription( suffix="self_attn", target_module=NopadLlamaAttention, + kwargs={ + "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"], + }, ), ], ) diff --git a/colossalai/inference/spec/struct.py b/colossalai/inference/spec/struct.py index 143f26d09..9b52437db 100644 --- a/colossalai/inference/spec/struct.py +++ b/colossalai/inference/spec/struct.py @@ -46,6 +46,7 @@ class GlideInput: large_k_cache: torch.Tensor = None large_v_cache: torch.Tensor = None sequence_lengths: torch.Tensor = None + n_spec_tokens: int = 5 @property def glimpse_ready(self): diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index 072bedec3..8c155e6ca 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -1,6 +1,7 @@ """ Utils for model inference """ +import math import os import re from pathlib import Path @@ -9,8 +10,11 @@ from typing import Optional, Tuple import torch from torch import nn +from colossalai.logging import get_dist_logger from colossalai.testing import free_port +logger = get_dist_logger(__name__) + def init_to_get_rotary(self, base=10000, use_elem=False): """ @@ -113,3 +117,44 @@ def find_available_ports(num: int): print(f"An OS error occurred: {e}") raise RuntimeError("Error finding available ports") return free_ports + + +def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor: + """ + Alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57 + + Args: + num_heads (int): The number of attention heads. + device (torch.device): The device to use. + + Returns: + torch.Tensor: The Alibi slopes. + """ + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device) + slopes = torch.pow(base, powers) + if closest_power_of_2 != num_heads: + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device + ) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + +def can_use_flash_attn2(dtype: torch.dtype) -> bool: + """ + Check flash attention2 availability. + """ + if dtype not in (torch.float16, torch.bfloat16): + return False + + try: + from flash_attn import flash_attn_varlen_func # noqa + + return True + except ImportError: + logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") + return False diff --git a/colossalai/initialize.py b/colossalai/initialize.py index 934555e19..71d42312e 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -45,7 +45,10 @@ def launch( backend = cur_accelerator.communication_backend # init default process group - init_method = f"tcp://[{host}]:{port}" + if ":" in host: # IPv6 + init_method = f"tcp://[{host}]:{port}" + else: # IPv4 + init_method = f"tcp://{host}:{port}" dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) # set cuda device diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index abc865a34..141baf3d3 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -50,7 +50,7 @@ def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.T seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return max_seqlen_in_batch, cu_seqlens, indices diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 5aa212600..59e1da9fc 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -140,32 +140,29 @@ class RMSNorm(BaseLayerNorm): class LayerNorm(BaseLayerNorm): r""" - This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface. + This is a wrapper around native LayerNorm. It is meant to be used only with the from_native_module interface. """ def __init__(self) -> None: raise NotImplementedError( "LayerNorm is not implemented as a physical class. " - "It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module." + "It is meant to be used only with the from_native_module interface to convert a native LayerNorm module to colossalai layer norm module." ) @staticmethod - def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: + def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: r""" - Convert a native pytorch layer norm module to colossalai layer norm module, + Convert a native LayerNorm module to colossalai layer norm module, and optionally marking parameters for gradient aggregation. Args: - module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted. + module (nn.Module): The native LayerNorm module to be converted. sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. Returns: - nn.Module: The LayerNorm module. + nn.Module: The colossalai LayerNorm module. - Raises: - AssertionError: If the provided module is not an instance of nn.LayerNorm. """ - assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm." LazyInitContext.materialize(module) @@ -174,7 +171,8 @@ class LayerNorm(BaseLayerNorm): # aggregation of these gradients is necessary during backpropagation. # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight) - SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias) + if module.bias is not None: + SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias) return module @@ -187,31 +185,29 @@ class FusedLayerNorm(BaseLayerNorm): def __init__(self) -> None: raise NotImplementedError( "FusedLayerNorm is not implemented as a physical class. " - "It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex." + "It is meant to be used only with the from_native_module interface convert a native LayerNorm module to FusedLayerNorm module provided by apex." ) @staticmethod def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: r""" - Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex, + Convert a native LayerNorm module to FusedLayerNorm module provided by apex, and optionally marking parameters for gradient aggregation. Args: - module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted. + module (nn.Module): The native LayerNorm module to be converted. sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. Returns: nn.Module: Union[FastLayerNorm, FusedLayerNorm]. - Raises: - AssertionError: If the provided module is not an instance of nn.LayerNorm. """ LazyInitContext.materialize(module) # get the attributes of the module - normalized_shape = module.normalized_shape - eps = module.eps - elementwise_affine = module.elementwise_affine + normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0]) + eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps + elementwise_affine = getattr(module, "elementwise_affine", True) dtype = module.weight.dtype device = module.weight.device @@ -229,7 +225,7 @@ class FusedLayerNorm(BaseLayerNorm): ApexFusedLayerNorm = FusedLayerNormWithHook except NameError: warnings.warn( - "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead." + "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using native layernorm instead." ) return module @@ -237,7 +233,8 @@ class FusedLayerNorm(BaseLayerNorm): ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) ) layernorm.weight = module.weight - layernorm.bias = module.bias + if module.bias is not None: + layernorm.bias = module.bias if sp_partial_derived: # Since gradients are computed using only a subset of the data, diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index bf74d0833..1f34215c5 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -475,7 +475,10 @@ class BloomPipelineForwards: sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) else: sequence_lengths = -1 logger.warning( diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py new file mode 100644 index 000000000..07a7f6cbf --- /dev/null +++ b/colossalai/shardformer/modeling/command.py @@ -0,0 +1,692 @@ +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.cohere.modeling_cohere import ( + CohereForCausalLM, + CohereModel, + StaticCache, + apply_rotary_pos_emb, + repeat_kv, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import ( + all_to_all_comm, + gather_forward_split_backward, + split_forward_gather_backward, +) +from colossalai.shardformer.shard import ShardConfig + +from ..layer import ColoAttention, cross_entropy_1d + + +class CommandPipelineForwards: + """ + This class serves as a micro library for forward function substitution of Command models + under pipeline setting. + """ + + @staticmethod + def command_model_forward( + self: CohereModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with pipeline parallelism. Setting `use_cache=False`..." + ) + use_cache = False + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + past_seen_tokens = 0 + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device) + + seq_length_with_past = seq_length + past_seen_tokens + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if shard_config.enable_flash_attention: + # in this case, attention_mask is a dict rather than a tensor + mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + else: + attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) + + if self.gradient_checkpointing and self.training and use_cache: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + start_idx, end_idx = stage_index[0], stage_index[1] + num_ckpt_layers = 0 + if self.gradient_checkpointing and self.training: + num_ckpt_layers = end_idx - start_idx + # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer + if shard_config.gradient_checkpoint_config is not None: + num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( + stage=stage_manager.stage, + num_stages=stage_manager.num_stages, + num_layers=end_idx - start_idx, + model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), + num_model_chunks=stage_manager.num_model_chunks, + ) + assert num_ckpt_layers <= end_idx - start_idx + + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if idx - start_idx < num_ckpt_layers: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if stage_manager.is_last_stage(): + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + ] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + # always return dict for imediate stage + return {"hidden_states": hidden_states} + + @staticmethod + def command_for_causal_lm_forward( + self: CohereForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, CohereForCausalLM + + >>> model = CohereForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = CommandPipelineForwards.command_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + past_key_values = None + + if stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits * self.logit_scale + logits = logits.float() + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + if shard_config.enable_tensor_parallelism and shard_config.parallel_output: + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + dtype=self.model.dtype, + ) + else: + shift_logits = shift_logits.view(-1, self.config.vocab_size) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + +def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if sp_mode is not None: + assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode" + assert (sp_size is not None) and ( + sp_group is not None + ), "Must specify sp_size and sp_group for sequence parallel" + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + bsz, q_len, _ = hidden_states.size() + # sp: modify sp_len when sequence parallel mode is ring + if sp_mode in ["split_gather", "ring"]: + q_len *= sp_size + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + query_states = all_to_all_comm(query_states, sp_group) + key_states = all_to_all_comm(key_states, sp_group) + value_states = all_to_all_comm(value_states, sp_group) + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if shard_config.enable_flash_attention: + assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." + attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) + else: + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_value + + return forward + + +def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): + logger = logging.get_logger(__name__) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + past_seen_tokens = 0 + seq_len = inputs_embeds.shape[1] + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # in this case, attention_mask is a dict rather than a tensor + if shard_config.enable_flash_attention: + mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + inputs_embeds.dtype, + inputs_embeds.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + else: + attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + + if sp_mode in ["ring", "split_gather"]: + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + elif sp_mode == "all_to_all": + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if sp_mode == "ring" or sp_mode == "split_gather": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return forward + + +def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): + from transformers import CohereForCausalLM + + def forward( + self: CohereForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, CohereForCausalLM + + >>> model = CohereForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + logits = self.lm_head(hidden_states) + logits = logits * self.logit_scale + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + dtype=self.model.dtype, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return forward diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py index a43bdf481..8181a68a0 100644 --- a/colossalai/shardformer/modeling/falcon.py +++ b/colossalai/shardformer/modeling/falcon.py @@ -291,18 +291,17 @@ class FalconPipelineForwards: if attention_mask_2d is None: attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads) else: + min_dtype = torch.finfo(alibi.dtype).min attention_mask = torch.masked_fill( alibi / math.sqrt(self.config.hidden_size // self.num_heads), attention_mask < -1, - torch.finfo(alibi.dtype).min, + min_dtype, ) # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 - if seq_length > 1: - attention_mask = AttentionMaskConverter._unmask_unattended( - attention_mask, attention_mask_2d, unmasked_value=0.0 - ) + if seq_length > 1 and attention_mask.device.type == "cuda": + attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype) else: # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case. attention_mask = _prepare_4d_causal_attention_mask( @@ -543,7 +542,10 @@ class FalconPipelineForwards: sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1).to(logits.device) + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) else: sequence_lengths = -1 logger.warning( diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index c49458dbd..aa75bab11 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -738,7 +738,10 @@ class GPT2PipelineForwards: sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) else: sequence_lengths = -1 logger.warning_once( diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py index 4f4cec8bc..facd2fcaf 100644 --- a/colossalai/shardformer/modeling/gptj.py +++ b/colossalai/shardformer/modeling/gptj.py @@ -32,6 +32,7 @@ def _get_attention_mask( hidden_states: torch.Tensor, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]], attention_mask: Optional[torch.FloatTensor], + use_flash_attention_2: bool = False, ) -> Optional[Union[torch.Tensor, dict]]: batch_size, seq_len = hidden_states.shape[:2] past_key_values_length = 0 @@ -47,7 +48,7 @@ def _get_attention_mask( attention_mask, is_causal=True, ) - elif attention_mask is not None: + elif use_flash_attention_2 and attention_mask is not None: if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") attention_mask = attention_mask.view(batch_size, -1) @@ -162,7 +163,9 @@ class GPTJPipelineForwards: output_shape = input_shape + (hidden_states.size(-1),) - attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask) + attention_mask = _get_attention_mask( + self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2 + ) if self.gradient_checkpointing and self.training: if use_cache: @@ -419,7 +422,10 @@ class GPTJPipelineForwards: sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) else: sequence_lengths = -1 logger.warning_once( @@ -712,7 +718,9 @@ def gptj_model_forward_for_flash_attention(shard_config: ShardConfig): hidden_states = self.drop(hidden_states) - attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask) + attention_mask = _get_attention_mask( + self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2 + ) output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) @@ -886,7 +894,9 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states = self.drop(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) - attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask) + attention_mask = _get_attention_mask( + self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2 + ) if self.gradient_checkpointing and self.training: if use_cache: diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 01d10c8dc..bf5ce45a8 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -7,11 +7,7 @@ import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.cache_utils import Cache -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) +from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -21,6 +17,7 @@ from transformers.models.llama.modeling_llama import ( LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, + StaticCache, apply_rotary_pos_emb, repeat_kv, ) @@ -55,6 +52,7 @@ class LlamaPipelineForwards: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -67,6 +65,11 @@ class LlamaPipelineForwards: output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with pipeline parallelism. Setting `use_cache=False`..." + ) + use_cache = False return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -83,14 +86,24 @@ class LlamaPipelineForwards: device = input_ids.device if input_ids is not None else inputs_embeds.device if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds else: input_shape = hidden_states.shape[:-1] batch_size, seq_length = input_shape device = hidden_states.device - seq_length_with_past = seq_length - past_key_values_length = 0 + past_seen_tokens = 0 + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device) + + seq_length_with_past = seq_length + past_seen_tokens # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: @@ -103,18 +116,8 @@ class LlamaPipelineForwards: logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - if position_ids is None: - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0) + position_ids = cache_position.unsqueeze(0) # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage @@ -129,28 +132,9 @@ class LlamaPipelineForwards: is_causal=True, ) else: - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - hidden_states, - past_key_values_length, - ) + attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) - if self.gradient_checkpointing and self.training: + if self.gradient_checkpointing and self.training and use_cache: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." @@ -190,6 +174,7 @@ class LlamaPipelineForwards: past_key_values, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -199,6 +184,7 @@ class LlamaPipelineForwards: past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -249,6 +235,7 @@ class LlamaPipelineForwards: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -306,6 +293,7 @@ class LlamaPipelineForwards: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, @@ -368,6 +356,7 @@ class LlamaPipelineForwards: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -401,6 +390,7 @@ class LlamaPipelineForwards: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, @@ -470,36 +460,53 @@ class LlamaPipelineForwards: return {"hidden_states": hidden_states} -def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size): - from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - - try: - from transformers.models.llama.modeling_llama import repeat_kv - except: - warnings.warn("using llamav1, llamav1 hasn't repeat_kv function") - +def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): def forward( - self: LlamaAttention, + self, hidden_states: torch.Tensor, - attention_mask: Optional[dict] = None, + attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if sp_mode is not None: + assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode" + assert (sp_size is not None) and ( + sp_group is not None + ), "Must specify sp_size and sp_group for sequence parallel" if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) - bsz, q_len, _ = hidden_states.size() + bsz, q_len, _ = hidden_states.size() + # sp: modify sp_len when sequence parallel mode is ring if sp_mode in ["split_gather", "ring"]: q_len *= sp_size - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": @@ -520,39 +527,76 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size): "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." - attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + if shard_config.enable_flash_attention: + assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." + attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) - attn_output = self.o_proj(attn_output) + else: + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - return attn_output, None, past_key_value + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_value return forward -def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig): +def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): logger = logging.get_logger(__name__) - assert shard_config.enable_flash_attention, "Flash Attention is not enabled." def forward( - self: LlamaModel, + self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -562,119 +606,122 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False - seq_length_with_past = seq_length - past_key_values_length = 0 + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length + past_seen_tokens = 0 + seq_len = inputs_embeds.shape[1] + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device) if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, + position_ids = cache_position.unsqueeze(0) + + # in this case, attention_mask is a dict rather than a tensor + if shard_config.enable_flash_attention: + mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + inputs_embeds.dtype, + inputs_embeds.device, + q_padding_mask=attention_mask, + is_causal=True, ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: - position_ids = position_ids.view(-1, seq_length).long() + attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions + if sp_mode in ["ring", "split_gather"]: + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + elif sp_mode == "all_to_all": + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) hidden_states = inputs_embeds - # in this case, attention_mask is a dict rather than a tensor - mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) - attention_mask = ColoAttention.prepare_attn_kwargs( - mask_shape, - hidden_states.dtype, - hidden_states.device, - q_padding_mask=attention_mask, - is_causal=True, - ) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None - for idx, decoder_layer in enumerate(self.layers): + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, ) + else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) + if sp_mode == "ring" or sp_mode == "split_gather": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -700,6 +747,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -744,6 +792,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] @@ -786,266 +835,3 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): ) return forward - - -def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - # sp: modify sp_len when sequence parallel mode is ring - if sp_mode in ["split_gather", "ring"]: - q_len *= sp_size - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # sp: all-to-all comminucation when introducing sequence parallel - if sp_mode == "all_to_all": - query_states = all_to_all_comm(query_states, sp_group) - key_states = all_to_all_comm(key_states, sp_group) - value_states = all_to_all_comm(value_states, sp_group) - bsz, q_len, _ = query_states.size() - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - # sp: all-to-all comminucation when introducing sequence parallel - if sp_mode == "all_to_all": - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) - attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) - else: - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value - - return forward - - -def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): - logger = logging.get_logger(__name__) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - # modify past_key_values_length when using sequence parallel - past_key_values_length *= sp_size - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if sp_mode in ["ring", "split_gather"]: - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) - elif sp_mode == "all_to_all": - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) - - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device, - ) - - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - ) - - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - if sp_mode == "ring" or sp_mode == "split_gather": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) - elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - return forward diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 5f96ebe3d..310c2d8e2 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -4,7 +4,10 @@ from typing import List, Optional, Tuple, Union import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -77,7 +80,7 @@ class MistralForwards: else: position_ids = position_ids.view(-1, seq_length).long() - if attention_mask is not None and self._use_flash_attention_2 and use_cache: + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( @@ -97,9 +100,18 @@ class MistralForwards: is_causal=True, ) else: - if self._use_flash_attention_2: + if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( @@ -462,7 +474,7 @@ def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._use_flash_attention_2 and use_cache: + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( @@ -481,9 +493,18 @@ def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig): is_causal=True, ) else: - if self._use_flash_attention_2: + if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 6d7df963a..cf925983b 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -17,6 +17,7 @@ from transformers.modeling_outputs import ( SequenceClassifierOutput, ) from transformers.models.whisper.modeling_whisper import ( + _HIDDEN_STATES_START_POSITION, WhisperDecoder, WhisperEncoder, WhisperForAudioClassification, @@ -166,6 +167,7 @@ def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig): cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, + position_ids=None, use_cache=None, output_attentions=None, output_hidden_states=None, @@ -199,9 +201,13 @@ def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig): # embed positions if input_ids is not None: - positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + positions = self.embed_positions( + input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids + ) else: - positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) + positions = self.embed_positions( + inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids + ) hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -599,6 +605,7 @@ class WhisperPipelineForwards: cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, + position_ids=None, use_cache=None, output_attentions=None, output_hidden_states=None, @@ -716,9 +723,13 @@ class WhisperPipelineForwards: # embed positions if input_ids is not None: - positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + positions = self.embed_positions( + input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids + ) else: - positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) + positions = self.embed_positions( + inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids + ) hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -841,6 +852,7 @@ class WhisperPipelineForwards: encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -944,6 +956,7 @@ class WhisperPipelineForwards: cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, + position_ids=decoder_position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -986,6 +999,7 @@ class WhisperPipelineForwards: encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1048,6 +1062,7 @@ class WhisperPipelineForwards: cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, decoder_inputs_embeds=decoder_inputs_embeds, + decoder_position_ids=decoder_position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1118,6 +1133,12 @@ class WhisperPipelineForwards: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + + if self.config.use_weighted_layer_sum: + output_hidden_states = True + elif output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # audio_classification only holds encoder @@ -1138,7 +1159,8 @@ class WhisperPipelineForwards: return encoder_outputs if self.config.use_weighted_layer_sum: - hidden_states = torch.stack(encoder_outputs, dim=1) + hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) else: diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 69df021b0..008dead6b 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -192,6 +192,13 @@ _POLICY_LIST = { "transformers.models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification": PolicyLocation( file_name="qwen2", class_name="Qwen2ForSequenceClassificationPolicy" ), + # Command-R + "transformers.models.cohere.modeling_cohere.CohereModel": PolicyLocation( + file_name="command", class_name="CommandModelPolicy" + ), + "transformers.models.cohere.modeling_cohere.CohereForCausalLM": PolicyLocation( + file_name="command", class_name="CommandForCausalLMPolicy" + ), } diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 0c04f7d38..c11ed99ac 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -67,7 +67,7 @@ class BertPolicy(Policy): else: norm_cls = col_nn.LayerNorm - sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + sp_mode = self.shard_config.sequence_parallelism_mode or None assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for Bert" if sp_mode == "ring": warnings.warn( diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 724a6b77c..20a75cf90 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -50,7 +50,7 @@ class BloomPolicy(Policy): else: norm_cls = col_nn.LayerNorm - sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + sp_mode = self.shard_config.sequence_parallelism_mode or None assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for BLOOM" if sp_mode == "ring": warnings.warn( diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 4baf89f6a..01aa77e57 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -57,7 +57,7 @@ class ChatGLMPolicy(Policy): else: norm_cls = col_nn.LayerNorm - sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + sp_mode = self.shard_config.sequence_parallelism_mode or None assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for ChatGLM2" if sp_mode == "ring": warnings.warn( diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py new file mode 100644 index 000000000..902baf2e1 --- /dev/null +++ b/colossalai/shardformer/policies/command.py @@ -0,0 +1,369 @@ +import warnings +from functools import partial +from typing import Callable, Dict, List, Union + +import torch.nn as nn +from torch import Tensor +from torch.nn import Module + +from colossalai.shardformer.layer import ( + FusedLayerNorm, + LayerNorm, + Linear1D_Col, + Linear1D_Row, + PaddingEmbedding, + PaddingLMHead, + VocabParallelEmbedding1D, + VocabParallelLMHead1D, +) + +from ..modeling.command import ( + CommandPipelineForwards, + get_command_flash_attention_forward, + get_command_flash_attention_model_forward, + get_lm_forward_with_dist_cross_entropy, +) +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["CommandPolicy", "CommandForCausalLMPolicy"] + + +class CommandPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + from transformers.models.cohere.modeling_cohere import ( + CohereAttention, + CohereDecoderLayer, + CohereFlashAttention2, + CohereModel, + CohereSdpaAttention, + ) + + ATTN_IMPLEMENTATION = { + "eager": CohereAttention, + "flash_attention_2": CohereFlashAttention2, + "sdpa": CohereSdpaAttention, + } + policy = {} + + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + + if self.shard_config.enable_fused_normalization: + norm_cls = FusedLayerNorm + else: + norm_cls = LayerNorm + + if self.pipeline_stage_manager is not None: + self.shard_config.enable_sequence_parallelism = False + self.shard_config.enable_sequence_overlap = False + self.shard_config.sequence_parallelism_mode = None + warnings.warn( + f"For Command, sequence parallelism is currently not compatible with pipeline parallelism, set to be False" + ) + sp_mode = self.shard_config.sequence_parallelism_mode or None + sp_size = self.shard_config.sequence_parallel_size or None + sp_group = self.shard_config.sequence_parallel_process_group or None + sp_partial_derived = sp_mode in ["split_gather", "ring"] + + if sp_mode == "all_to_all": + decoder_attribute_replacement = { + "num_heads": self.model.config.num_attention_heads // sp_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size + + policy[attn_cls] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + ) + if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism: + self.append_or_create_method_replacement( + description={ + "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), + }, + policy=policy, + target_key=attn_cls, + ) + if self.pipeline_stage_manager is None: + self.append_or_create_method_replacement( + description={ + "forward": get_command_flash_attention_model_forward( + self.shard_config, + sp_mode=sp_mode, + sp_size=sp_size, + sp_group=sp_group, + ), + }, + policy=policy, + target_key=CohereModel, + ) + + if self.shard_config.enable_tensor_parallelism: + assert ( + self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." + if hasattr(self.model.config, "num_key_value_heads"): + assert ( + self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size + and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size." + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( + self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size + ) + + policy[CohereDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + kwargs=dict(seq_parallel_mode=sp_mode), + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=Linear1D_Row, + kwargs=dict(seq_parallel_mode=sp_mode), + ), + ], + ) + + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=CohereModel, + ) + + # optimization configuration + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, + ), + ], + policy=policy, + target_key=CohereDecoderLayer, + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", + target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, + ), + policy=policy, + target_key=CohereModel, + ) + + return policy + + def postprocess(self): + return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager is None: + return + + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "CohereModel": + module = self.model + else: + module = self.model.model + + if stage_manager.is_interleave: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) + method_replacement = { + "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) + } + + else: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } + + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "CohereModel": + module = self.model + else: + module = self.model.model + stage_manager = self.pipeline_stage_manager + + held_layers = [] + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embed_tokens) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(module.norm) + + else: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + + return held_layers + + +class CommandModelPolicy(CommandPolicy): + def module_policy(self): + policy = super().module_policy() + from transformers.models.cohere.modeling_cohere import CohereModel + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=CohereModel, new_forward=CommandPipelineForwards.command_model_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in command model""" + return [] + + +class CommandForCausalLMPolicy(CommandPolicy): + def module_policy(self): + from transformers import CohereForCausalLM + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + CohereForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=VocabParallelLMHead1D, + kwargs={ + "gather_output": not self.shard_config.parallel_output, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, + ) + ], + ) + } + if self.shard_config.parallel_output: + new_item[CohereForCausalLM].method_replacement = { + "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) + } + else: + new_item = { + CohereForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ) + ], + ) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=CohereForCausalLM, + new_forward=CommandPipelineForwards.command_for_causal_lm_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + command_model = self.model.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if ( + id(command_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): + # tie weights + return [ + { + 0: command_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] + return [] diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 281ea88c2..cfe20000a 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -65,7 +65,7 @@ class GPT2Policy(Policy): else: norm_cls = col_nn.LayerNorm - sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + sp_mode = self.shard_config.sequence_parallelism_mode or None assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for GPT2" if sp_mode == "ring": warnings.warn( diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index 3315eb1e9..c394d911e 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -34,15 +34,11 @@ class GPTJPolicy(Policy): return self.model def module_policy(self): - from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel - - ATTN_IMPLEMENTATION = { - "eager": GPTJAttention, - } + from transformers.models.gptj.modeling_gptj import GPTJ_ATTENTION_CLASSES, GPTJBlock, GPTJModel policy = {} - attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + attn_cls = GPTJ_ATTENTION_CLASSES[self.origin_attn_implement] embedding_cls = None if self.shard_config.enable_tensor_parallelism: diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index a9c982231..85ec6717d 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -20,9 +20,7 @@ from colossalai.shardformer.layer import ( from ..modeling.llama import ( LlamaPipelineForwards, get_llama_flash_attention_forward, - get_llama_model_forward_for_flash_attn, - get_llama_seq_parallel_attention_forward, - get_llama_seq_parallel_model_forward, + get_llama_flash_attention_model_forward, get_lm_forward_with_dist_cross_entropy, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -75,40 +73,12 @@ class LlamaPolicy(Policy): warnings.warn( f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False" ) - sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None - sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None - sp_group = ( - self.shard_config.sequence_parallel_process_group if self.shard_config.enable_sequence_parallelism else None - ) + sp_mode = self.shard_config.sequence_parallelism_mode or None + sp_size = self.shard_config.sequence_parallel_size or None + sp_group = self.shard_config.sequence_parallel_process_group or None sp_partial_derived = sp_mode in ["split_gather", "ring"] - use_flash_attention = self.shard_config.enable_flash_attention - # Currently sp cannot to be used with flashattention - if sp_mode in ["split_gather", "ring", "all_to_all"]: - if use_flash_attention: - warnings.warn( - f"Sequence parallelism mode {sp_mode} need to be used with FlashAttention, will disable FlashAttention automatically." - ) - use_flash_attention = False - - if sp_mode in ["split_gather", "ring"]: - self.append_or_create_method_replacement( - description={ - "forward": get_llama_seq_parallel_model_forward( - sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group - ), - }, - policy=policy, - target_key=LlamaModel, - ) - self.append_or_create_method_replacement( - description={ - "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), - }, - policy=policy, - target_key=attn_cls, - ) - elif sp_mode == "all_to_all": + if sp_mode == "all_to_all": decoder_attribute_replacement = { "num_heads": self.model.config.num_attention_heads // sp_size, } @@ -118,24 +88,27 @@ class LlamaPolicy(Policy): policy[attn_cls] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, ) + if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism: self.append_or_create_method_replacement( description={ - "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), + "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), }, policy=policy, target_key=attn_cls, ) - self.append_or_create_method_replacement( - description={ - "forward": get_llama_seq_parallel_model_forward( - sp_mode=sp_mode, - sp_size=sp_size, - sp_group=sp_group, - ), - }, - policy=policy, - target_key=LlamaModel, - ) + if self.pipeline_stage_manager is None: + self.append_or_create_method_replacement( + description={ + "forward": get_llama_flash_attention_model_forward( + self.shard_config, + sp_mode=sp_mode, + sp_size=sp_size, + sp_group=sp_group, + ), + }, + policy=policy, + target_key=LlamaModel, + ) if self.shard_config.enable_tensor_parallelism: assert ( @@ -235,25 +208,6 @@ class LlamaPolicy(Policy): target_key=LlamaModel, ) - # use flash attention - if use_flash_attention: - self.append_or_create_method_replacement( - description={ - "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size), - }, - policy=policy, - target_key=attn_cls, - ) - if self.pipeline_stage_manager is None: - # replace llama model forward method - self.append_or_create_method_replacement( - description={ - "forward": get_llama_model_forward_for_flash_attn(self.shard_config), - }, - policy=policy, - target_key=LlamaModel, - ) - return policy def postprocess(self): diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 621982f29..c5a0277a5 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -42,11 +42,13 @@ class MistralPolicy(Policy): MistralDecoderLayer, MistralFlashAttention2, MistralModel, + MistralSdpaAttention, ) ATTN_IMPLEMENTATION = { "eager": MistralAttention, "flash_attention_2": MistralFlashAttention2, + "sdpa": MistralSdpaAttention, } policy = {} diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 45066ca89..3a5f0a5aa 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -25,6 +25,7 @@ class ChunkManager: chunk_configuration, init_device: Optional[torch.device] = None, reuse_fp16_chunk: bool = True, + max_prefetch: int = 0, ) -> None: self.device = init_device or get_accelerator().get_current_device() self.dp_degree_chunk_size_dict: Dict[int, int] = dict() @@ -42,6 +43,7 @@ class ChunkManager: # Whether model is accumulating gradients, self.accumulating_grads = False self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device()) + self._prefetch_stream = get_accelerator().Stream() if max_prefetch else None def register_tensor( self, diff --git a/colossalai/zero/gemini/chunk/utils.py b/colossalai/zero/gemini/chunk/utils.py index 049c5c102..884d1306e 100644 --- a/colossalai/zero/gemini/chunk/utils.py +++ b/colossalai/zero/gemini/chunk/utils.py @@ -21,6 +21,7 @@ def init_chunk_manager( hidden_dim: Optional[int] = None, reuse_fp16_chunk: bool = True, verbose: bool = False, + max_prefetch: int = 0, **kwargs, ) -> ChunkManager: if hidden_dim: @@ -51,9 +52,5 @@ def init_chunk_manager( ) dist.barrier() - chunk_manager = ChunkManager( - config_dict, - init_device, - reuse_fp16_chunk=reuse_fp16_chunk, - ) + chunk_manager = ChunkManager(config_dict, init_device, reuse_fp16_chunk=reuse_fp16_chunk, max_prefetch=max_prefetch) return chunk_manager diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 6f6064000..ebdde83b4 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -104,9 +104,7 @@ class GeminiDDP(ModelWrapper): self.enable_gradient_accumulation = enable_gradient_accumulation if chunk_config_dict is not None: self.chunk_manager = ChunkManager( - chunk_config_dict, - chunk_init_device, - reuse_fp16_chunk=reuse_fp16_chunk, + chunk_config_dict, chunk_init_device, reuse_fp16_chunk=reuse_fp16_chunk, max_prefetch=max_prefetch ) else: # some ugly hotfix for the compatibility with Lightning @@ -122,6 +120,7 @@ class GeminiDDP(ModelWrapper): process_group=zero_group, reuse_fp16_chunk=reuse_fp16_chunk, verbose=verbose, + max_prefetch=max_prefetch, ) self.gemini_manager = GeminiManager( placement_policy, @@ -451,6 +450,7 @@ class GeminiDDP(ModelWrapper): chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True) if not (master_weights) or (enable_gradient_accumulation): chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True) + return empty_grad def zero_grad(self, set_to_none: bool = False) -> None: self.module.zero_grad(set_to_none=True) diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 736238a09..bf5faa0fe 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -5,6 +5,7 @@ from typing import List import torch +from colossalai.accelerator import get_accelerator from colossalai.tensor.param_op_hook import ColoParamOpHook from colossalai.utils import is_ddp_ignored from colossalai.zero.gemini import TensorState @@ -54,10 +55,20 @@ class GeminiZeROHook(ColoParamOpHook): ) # prefetch - for chunk in chunks_fetch_async: - maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True) - if maybe_work is not None: - self._gemini_manager.add_work(chunk, maybe_work) + if self._gemini_manager.chunk_manager._prefetch_stream is not None: + # This is when prefetch happens the first time and there is no dist.Work to sync, + # there is possibility that the optimizer haven't finish computation on default stream, + # thus we might prefetch outdated chunks there. + # + # Other than that, self._gemini_manager.wait_chunks will have synced with default stream + # by calling dist.Work.wait() and this line makes no diff. + self._gemini_manager.chunk_manager._prefetch_stream.wait_stream(torch.cuda.current_stream()) + + with get_accelerator().stream(self._gemini_manager.chunk_manager._prefetch_stream): + for chunk in chunks_fetch_async: + maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True) + if maybe_work is not None: + self._gemini_manager.add_work(chunk, maybe_work) # record cuda model data of the current OP, including memory for prefetched chunks self._gemini_manager.record_model_data_volume() diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index 5878abbaa..dc97b461a 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -24,6 +24,7 @@ ## 新闻 +* [2024/06] [Open-Sora Continues Open Source: Generate Any 16-Second 720p HD Video with One Click, Model Weights Ready to Use](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use) * [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference) * [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source) * [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series) @@ -31,10 +32,6 @@ * [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0) * [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora) * [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer) -* [2024/01] [Construct Refined 13B Private Model With Just $5000 USD, Upgraded Colossal-AI Llama-2 Open Source](https://hpc-ai.com/blog/colossal-llama-2-13b) -* [2023/11] [Enhanced MoE Parallelism, Open-source MoE Model Training Can Be 9 Times More Efficient](https://www.hpc-ai.tech/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient) -* [2023/09] [One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific LLM Solution](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) -* [2023/09] [70 Billion Parameter LLaMA2 Model Training Accelerated by 195%](https://www.hpc-ai.tech/blog/70b-llama2-training) * [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth) ## 目录 @@ -127,13 +124,13 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 [Open-Sora](https://github.com/hpcaitech/Open-Sora):全面开源类Sora模型参数和所有训练细节 [[代码]](https://github.com/hpcaitech/Open-Sora) -[[博客]](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source) -[[模型权重]](https://huggingface.co/hpcai-tech/Open-Sora) +[[博客]](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use) +[[模型权重]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#model-weights) [[演示样例]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo) diff --git a/docs/sidebars.json b/docs/sidebars.json index 123211db5..754600627 100644 --- a/docs/sidebars.json +++ b/docs/sidebars.json @@ -56,6 +56,7 @@ "features/pipeline_parallel", "features/nvme_offload", "features/lazy_init", + "features/distributed_optimizers", "features/cluster_utils" ] }, diff --git a/docs/source/en/features/distributed_optimizers.md b/docs/source/en/features/distributed_optimizers.md index bdd7a3f40..f95b23304 100644 --- a/docs/source/en/features/distributed_optimizers.md +++ b/docs/source/en/features/distributed_optimizers.md @@ -4,9 +4,9 @@ Author: [Wenxuan Tan](https://github.com/Edenzzzz), [Junwen Duan](https://github **Related Paper** - [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235) -- [CAME: Confidence-guided Adaptive Memory Efficient Optimization] (https://arxiv.org/abs/2307.02047) -- [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection] (https://arxiv.org/abs/2403.03507) -- [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes] (https://arxiv.org/pdf/1904.00962) +- [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047) +- [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507) +- [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/pdf/1904.00962) ## Introduction Apart from the widely adopted Adam and SGD, many modern optimizers require layer-wise statistics to update parameters, and thus aren't directly applicable to settings where model layers are sharded across multiple devices. We provide optimized distributed implementations with minimal extra communications, and seamless integrations with Tensor Parallel, DDP and ZeRO plugins, which automatically uses distributed optimizers with 0 code change. @@ -14,12 +14,6 @@ Apart from the widely adopted Adam and SGD, many modern optimizers require layer ## Optimizers Adafactor is a first-order Adam variant using Non-negative Matrix Factorization(NMF) to reduce memory footprint. CAME improves by introducting a confidence matrix to correct NMF. GaLore further reduces memory by projecting gradients into a low-rank space and 8-bit block-wise quantization. Lamb allows huge batch sizes without lossing accuracy via layer-wise adaptive update bounded by the inverse of its Lipschiz constant. -## API Reference - -{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} -{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }} -{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }} -{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }} ## Hands-On Practice We now demonstrate how to use Distributed Adafactor with booster API combining Tensor Parallel and ZeRO 2 with 4 GPUs. **Note that even if you're not aware of distributed optimizers, the plugins automatically casts yours to the distributed version for convenience.** @@ -140,3 +134,10 @@ optim = DistGaloreAwamW( + +## API Reference + +{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} +{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }} +{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }} +{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }} diff --git a/docs/source/zh-Hans/features/distributed_optimizers.md b/docs/source/zh-Hans/features/distributed_optimizers.md index 36dbdf948..7a7068077 100644 --- a/docs/source/zh-Hans/features/distributed_optimizers.md +++ b/docs/source/zh-Hans/features/distributed_optimizers.md @@ -4,21 +4,15 @@ Author: Wenxuan Tan, Junwen Duan, Renjie Mao **相关论文** - [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235) -- [CAME: Confidence-guided Adaptive Memory Efficient Optimization] (https://arxiv.org/abs/2307.02047) -- [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection] (https://arxiv.org/abs/2403.03507) -- [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes] (https://arxiv.org/pdf/1904.00962) +- [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047) +- [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507) +- [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/pdf/1904.00962) ## 介绍 除了广泛采用的Adam和SGD外,许多现代优化器需要逐层统计信息以有效更新参数,因此无法直接应用于模型层在多个设备上分片的并行设置。我们以提供了优化的分布式实现,,并且通过plugin与Tensor Parallel、DDP和ZeRO无缝集成。 ## 优化器 Adafactor 是一种首次采用非负矩阵分解(NMF)的 Adam 变体,用于减少内存占用。CAME 通过引入一个置信度矩阵来改进 NMF 的效果。GaLore 通过将梯度投影到低秩空间,并使用 8 位块状量化进一步减少内存占用。Lamb 允许使用巨大的批量大小而不失准确性,通过按其 Lipschitz 常数的倒数界定的逐层自适应更新实现 -## API 参考 - -{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} -{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }} -{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }} -{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }} ## 使用 现在我们展示如何使用分布式 Adafactor 与 booster API 结合 Tensor Parallel 和 ZeRO 2。即使您不使用distributed optimizer,plugin 也会自动将optimizer转换为分布式版本以方便使用。 @@ -137,3 +131,10 @@ optim = DistGaloreAwamW( + +## API 参考 + +{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} +{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }} +{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }} +{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }} diff --git a/examples/inference/llama/README.md b/examples/inference/llama/README.md index cde81a41d..dae7f771c 100644 --- a/examples/inference/llama/README.md +++ b/examples/inference/llama/README.md @@ -43,5 +43,8 @@ colossalai run --nproc_per_node 2 llama_generation.py -m PATH_MODEL --drafter_mo If you want to try the GLIDE model (glide-vicuna7b) as the drafter model with vicuna-7B, you could provide the GLIDE model path or model card as drafter model and enable the feature by ```python +from colossalai.inference.modeling.models.glide_llama import GlideLlamaForCausalLM +drafter_model = GlideLlamaForCausalLM.from_pretrained(drafter_model_path_or_name) +... engine.enable_spec_dec(drafter_model, use_glide_drafter=True) ``` diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 8d4dae314..f6c975305 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -72,6 +72,7 @@ def main(): parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini") parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini") parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") + parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size") parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini") parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel") @@ -174,6 +175,8 @@ def main(): tp_size=args.tp, pp_size=args.pp, zero_stage=args.zero, + sp_size=args.sp, + enable_sequence_parallelism=args.sp > 1, enable_fused_normalization=torch.cuda.is_available(), enable_flash_attention=args.xformers, microbatch_size=args.mbs, diff --git a/requirements/requirements.txt b/requirements/requirements.txt index fa88501ef..27bbc3769 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -16,7 +16,7 @@ ray sentencepiece google protobuf -transformers>=4.36.2,<4.40.0 +transformers==4.39.3 peft>=0.7.1 bitsandbytes>=0.39.0 rpyc==6.0.0 diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index d5bddcff0..05c17f562 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -22,3 +22,9 @@ try: from .qwen2 import * except ImportError: print("This version of transformers doesn't support qwen2.") + + +try: + from .command import * +except ImportError: + print("This version of transformers doesn't support Command-R.") diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py index f443553bb..9a7cf34c1 100644 --- a/tests/kit/model_zoo/transformers/chatglm2.py +++ b/tests/kit/model_zoo/transformers/chatglm2.py @@ -33,22 +33,6 @@ loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss( ) loss_fn = lambda x: x["loss"] -config = AutoConfig.from_pretrained( - "THUDM/chatglm2-6b", - trust_remote_code=True, - num_layers=2, - padded_vocab_size=65024, - hidden_size=64, - ffn_hidden_size=214, - num_attention_heads=8, - kv_channels=16, - rmsnorm=True, - original_rope=True, - use_cache=True, - multi_query_attention=False, - torch_dtype=torch.float32, -) - infer_config = AutoConfig.from_pretrained( "THUDM/chatglm2-6b", @@ -68,6 +52,21 @@ infer_config = AutoConfig.from_pretrained( def init_chatglm(): + config = AutoConfig.from_pretrained( + "THUDM/chatglm2-6b", + trust_remote_code=True, + num_layers=2, + padded_vocab_size=65024, + hidden_size=64, + ffn_hidden_size=214, + num_attention_heads=8, + kv_channels=16, + rmsnorm=True, + original_rope=True, + use_cache=True, + multi_query_attention=False, + torch_dtype=torch.float32, + ) model = AutoModelForCausalLM.from_config(config, empty_init=False, trust_remote_code=True) for m in model.modules(): if m.__class__.__name__ == "RMSNorm": diff --git a/tests/kit/model_zoo/transformers/command.py b/tests/kit/model_zoo/transformers/command.py new file mode 100644 index 000000000..a8b8842c5 --- /dev/null +++ b/tests/kit/model_zoo/transformers/command.py @@ -0,0 +1,79 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +try: + from transformers import CohereConfig + + HAS_COMMAND = True +except ImportError: + HAS_COMMAND = False + +if HAS_COMMAND: + # =============================== + # Register Command-R + # =============================== + + def data_gen(): + input_ids = torch.Tensor( + [ + [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], + [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], + ] + ).long() + + attention_mask = torch.Tensor( + [ + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ] + ).long() + + return dict(input_ids=input_ids, attention_mask=attention_mask) + + # label is needed for casual lm + def data_gen_for_casual_lm(): + data = data_gen() + labels = data["input_ids"].clone() + data["labels"] = labels + return data + + # transform the output to a dict + output_transform_fn = lambda x: x + + # function to get the loss + loss_fn = lambda output: output["last_hidden_state"].mean() + loss_fn_for_casual_lm = lambda output: output["loss"] + loss_fn_for_seq_classification = lambda output: output["logits"].mean() + + config = CohereConfig( + num_hidden_layers=8, + hidden_size=32, + intermediate_size=64, + num_attention_heads=4, + max_position_embeddings=128, + ) + + if hasattr(config, "pad_token_id"): + config.pad_token_id = config.eos_token_id + + # register the following models + # transformers.CohereModel, + # transformers.CohereForCausalLM, + model_zoo.register( + name="transformers_command", + model_fn=lambda: transformers.CohereModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), + ) + model_zoo.register( + name="transformers_command_for_casual_lm", + model_fn=lambda: transformers.CohereForCausalLM(config), + data_gen_fn=data_gen_for_casual_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_casual_lm, + model_attribute=ModelAttribute(has_control_flow=True), + ) diff --git a/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py b/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py index 0bd398e2e..e9bf24d53 100644 --- a/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py +++ b/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py @@ -4,7 +4,7 @@ import numpy as np import pytest import torch -from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes +from colossalai.inference.utils import get_alibi_slopes from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.utils import get_current_device from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask diff --git a/tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py b/tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py index d90f64690..c3f2d0144 100644 --- a/tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py +++ b/tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py @@ -26,7 +26,7 @@ def prepare_data( num_tokens = torch.sum(context_lengths).item() max_seq_len_in_batch = context_lengths.max() - cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0)) + cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.int32), (1, 0)) kv_size = (num_tokens, num_kv_heads, HEAD_DIM) key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) diff --git a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py index 8237384c0..57a82647d 100644 --- a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py +++ b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py @@ -28,15 +28,22 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype): torch.manual_seed(10) TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN # our crafted op equals to Transformers - x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype) - x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype) + x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype) + x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype) + + position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN)) + emb = LlamaRotaryEmbedding(D) - cos, sin = emb(x0, TOTAL_TOKENS) + + cos, sin = emb(x0, position_ids) + embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin) + cos = cos.reshape((TOTAL_TOKENS, -1)) + sin = sin.reshape((TOTAL_TOKENS, -1)) cos_2 = cos[:, : D // 2] sin_2 = sin[:, : D // 2] - position_ids = torch.arange(TOTAL_TOKENS) - embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids) - embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2) + x2 = x0.transpose(1, 2).reshape(TOTAL_TOKENS, H, D) + embd_stimulated_x = torch_rotary_emb(x2, cos_2, sin_2) + embd_stimulated_x = embd_stimulated_x.reshape((BATCH_SIZE, SEQ_LEN, H, D)).transpose(1, 2) assert torch.allclose(embd_x0, embd_stimulated_x) # create data diff --git a/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py b/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py index 9d76858ed..92173ac13 100644 --- a/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py +++ b/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py @@ -2,7 +2,7 @@ import pytest import torch from packaging import version -from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes +from colossalai.inference.utils import get_alibi_slopes from colossalai.kernel.triton import context_attention_unpadded from colossalai.utils import get_current_device from tests.test_infer.test_kernels.triton.kernel_utils import ( diff --git a/tests/test_infer/test_kernels/triton/test_decoding_attn.py b/tests/test_infer/test_kernels/triton/test_decoding_attn.py index 40a6eae58..aa2a7e2b4 100644 --- a/tests/test_infer/test_kernels/triton/test_decoding_attn.py +++ b/tests/test_infer/test_kernels/triton/test_decoding_attn.py @@ -3,7 +3,7 @@ import pytest import torch from packaging import version -from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes +from colossalai.inference.utils import get_alibi_slopes from colossalai.kernel.triton import flash_decoding_attention from colossalai.utils import get_current_device from tests.test_infer.test_kernels.triton.kernel_utils import ( diff --git a/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py index 570093693..78b7ba81c 100644 --- a/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py +++ b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py @@ -43,15 +43,19 @@ def torch_rotary_emb(x, cos, sin): def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout): TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN # our crafted op equals to Transformers - x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D) - x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D) + x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype) + x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype) emb = LlamaRotaryEmbedding(D) - cos, sin = emb(x0, TOTAL_TOKENS) + position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN)) + cos, sin = emb(x0, position_ids) + embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin) + cos = cos.reshape((TOTAL_TOKENS, -1)) + sin = sin.reshape((TOTAL_TOKENS, -1)) cos_2 = cos[:, :32] sin_2 = sin[:, :32] - position_ids = torch.arange(TOTAL_TOKENS) - embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids) - embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2) + x2 = x0.transpose(1, 2).reshape(TOTAL_TOKENS, H, D) + embd_stimulated_x = torch_rotary_emb(x2, cos_2, sin_2) + embd_stimulated_x = embd_stimulated_x.reshape((BATCH_SIZE, SEQ_LEN, H, D)).transpose(1, 2) assert torch.allclose(embd_x0, embd_stimulated_x) # create data diff --git a/tests/test_infer/test_models/test_baichuan.py b/tests/test_infer/test_models/test_baichuan.py index 736fab5ff..f24e1bb3f 100644 --- a/tests/test_infer/test_models/test_baichuan.py +++ b/tests/test_infer/test_models/test_baichuan.py @@ -55,7 +55,7 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa assert inference_engine.generation_config.max_new_tokens == output_len inference_engine.add_request(prompts=inputs) assert inference_engine.request_handler._has_waiting() - generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k) + generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=output_len) outputs = inference_engine.generate(generation_config=generation_config) else: if prompt_template: diff --git a/tests/test_infer/test_models/test_custom_model.py b/tests/test_infer/test_models/test_custom_model.py new file mode 100644 index 000000000..f78731acf --- /dev/null +++ b/tests/test_infer/test_models/test_custom_model.py @@ -0,0 +1,161 @@ +import os +import random + +import numpy as np +import pytest +import torch +import torch.distributed as dist +from torch.multiprocessing import Manager +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, LlamaForCausalLM, LlamaTokenizer + +import colossalai +import colossalai.inference.modeling.policy as policy +from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig +from colossalai.inference.core.engine import InferenceEngine +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + +# NOTE: To test a model with the inference engine, you need to provide the path to your +# local pretrained model weights in the MODEL_MAP dictionary +MODEL_MAP = { + "baichuan": { + "model": AutoModelForCausalLM, + "tokenizer": AutoTokenizer, + "policy": policy.NoPaddingBaichuanModelInferPolicy, + "model_name_or_path": "baichuan-inc/Baichuan2-13B-Base", # provide the path to local model weights + }, + "llama": { + "model": LlamaForCausalLM, + "tokenizer": LlamaTokenizer, + "policy": policy.NoPaddingLlamaModelInferPolicy, + "model_name_or_path": "meta-llama/Llama-2-70b-hf", + }, +} + +MODELS_TO_TEST = ["llama", "baichuan"] # Specify the models to test + + +@parameterize("model", MODELS_TO_TEST) +@parameterize("prompt_template", [None, "model_specific"]) +@parameterize("do_sample", [False]) +@parameterize("use_cuda_kernel", [True]) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +def test_model(model, prompt_template, do_sample, use_cuda_kernel): + model_path = MODEL_MAP[model]["model_name_or_path"] + if not os.path.exists(model_path): + pytest.skip( + f"There is no local model address included for {model}, please replace this address with a valid one." + ) + + if prompt_template == "model_specific": + prompt_template = model + + model_config = MODEL_MAP[model] + + kwargs1 = { + "model": model, + "use_engine": True, + "prompt_template": prompt_template, + "do_sample": do_sample, + "policy": model_config["policy"](), + "use_cuda_kernel": use_cuda_kernel, + } + + kwargs2 = { + "model": model, + "use_engine": False, + "prompt_template": prompt_template, + "do_sample": do_sample, + "policy": None, + "use_cuda_kernel": use_cuda_kernel, + } + + colossal_tp_1_output = run_engine(1, **kwargs1) + colossal_tp_2_output = run_engine(2, **kwargs1) + transformer_tp_1_output = run_engine(1, **kwargs2) + + for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output): + assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}" + assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}" + + +def run_engine(world_size, **kwargs): + manager = Manager() + result_list = manager.list([-1] * world_size) # Create a shared list + spawn(run_dist, world_size, func_to_run=_run_engine, ret=result_list, **kwargs) + return result_list[0] + + +def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs): + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") + + if ret: + ret[rank] = func_to_run(**kwargs) + else: + func_to_run(**kwargs) + + +def _run_engine(model, use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None, policy=None): + setup_seed(20) + model_config = MODEL_MAP[model] + model_name_or_path = model_config["model_name_or_path"] + tokenizer = model_config["tokenizer"].from_pretrained(model_name_or_path, use_fast=False, trust_remote_code=True) + model = model_config["model"].from_pretrained(model_name_or_path, trust_remote_code=True).half().cuda() + model = model.eval() + + inputs = [ + "Introduce some landmarks in Paris:", + ] + + output_len = 38 + + if do_sample: + top_p = 0.5 + top_k = 50 + else: + top_p = None + top_k = None + + if use_engine: + inference_config = InferenceConfig( + max_output_len=output_len, + prompt_template=prompt_template, + use_cuda_kernel=use_cuda_kernel, + tp_size=dist.get_world_size(), + ) + inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy) + assert inference_engine.generation_config.max_new_tokens == output_len + inference_engine.add_request(prompts=inputs) + assert inference_engine.request_handler._has_waiting() + generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=output_len) + outputs = inference_engine.generate(generation_config=generation_config) + else: + if prompt_template: + # apply prompt template + inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs] + tokenizer.pad_token = tokenizer.eos_token + tokenizer.pad_token_id = tokenizer.eos_token_id + inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"] + inputs = inputs.cuda() + generation_config = GenerationConfig( + do_sample=do_sample, + top_p=top_p, + top_k=top_k, + pad_token_id=tokenizer.pad_token_id, + max_new_tokens=output_len, + ) + outputs = model.generate(inputs, generation_config=generation_config) + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) + return outputs + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +if __name__ == "__main__": + test_model() diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py new file mode 100644 index 000000000..b73552cec --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -0,0 +1,322 @@ +import os + +import pytest +import torch +import torch.distributed as dist +from torch.testing import assert_close + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import PipelineGradientCheckpointConfig +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False) + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + if enable_gradient_checkpointing: + # org_model.gradient_checkpointing_enable() + sharded_model.unwrap().gradient_checkpointing_enable() + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # unwrap model + command_model = unwrap_model(org_model, "CohereModel", "model") + shard_command_model = unwrap_model(sharded_model, "CohereModel", "model") + + row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"] + col_layer_for_check = ["layers[0].self_attn.o_proj"] + # Here we check the grad of layernorm because an all-reduce operation should be performed during sequence parallelism + norm_layer_for_check = ["layers[0].input_layernorm", "layers[1].input_layernorm"] + + # During pipeline parallelism, we cannot get the grad of norm layer during first stage, so we only check this when pp is not enbaled + if stage_manager is None: + norm_layer_for_check.append("norm") + + # Check the grad when using ZeRO-1 and ZeRO-2 + if ( + booster.plugin.zero_stage in [1, 2] + and booster.plugin.shard_config.enable_sequence_parallelism + and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" + ): + for p1, p2 in zip(command_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): + working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)] + grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p)) + grad_index = ( + 0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank + ) + grad = grads[grad_index] + sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] + assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0: + if test_config["precision"] == "fp32": + atol, rtol = 1e-6, 1e-4 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check( + command_model, + shard_command_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, + ) + col_layer_grads = get_grad_tensors_for_check( + command_model, + shard_command_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) + norm_layer_grads = get_grad_tensors_for_check( + command_model, + shard_command_model, + norm_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + grads_to_check.update(norm_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == "CohereModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): + if test_config["precision"] == "fp32": + atol, rtol = 5e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + check_weight( + command_model, + shard_command_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) + + # check grads + check_all_grad_tensors(grads_to_check) + + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": False, + "precision": "fp32", + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) +def run_command_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "pp_style": "interleaved", + "num_model_chunks": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig( + num_ckpt_layers_per_stage=[0, 1, 2, 2], + ), + }, + ], +) +def run_command_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def check_command(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_command_test() + + +def check_command_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_command_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_command(): + spawn(check_command, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_command_3d(): + spawn(check_command_3d, 8) + + +if __name__ == "__main__": + test_command() + test_command_3d() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 1628bf2f3..3a8a1357d 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -120,9 +120,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_weight( - llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False - ) + try: + check_weight( + llama_model, + shard_llama_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) + except Exception as e: + print(f"Failed config: {test_config}") + raise e # check grads check_all_grad_tensors(grads_to_check) @@ -133,9 +144,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - { + { # Test ring + Flash attention "tp_size": 2, "pp_size": 1, + "sp_size": 2, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "ring", @@ -145,14 +157,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, - { - "tp_size": 4, - "pp_size": 1, - "num_microbatches": 1, + { # Ulysess + Flash attention + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": False, + "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": True, "use_lazy_init": True, + "zero_stage": 1, "precision": "fp16", "initial_scale": 1, }, @@ -164,7 +178,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "all_to_all", "use_lazy_init": True, - "zero_stage": 2, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, }, @@ -213,7 +238,11 @@ def run_llama_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + try: + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + except Exception as e: + print(f"Failed config: {test_config}") + raise e clear_layout_converter() Randomizer.reset_index() @@ -263,7 +292,11 @@ def run_llama_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + try: + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + except Exception as e: + print(f"Failed config: {test_config}") + raise e clear_layout_converter() Randomizer.reset_index() diff --git a/tests/test_shardformer/test_model/test_shard_qwen2.py b/tests/test_shardformer/test_model/test_shard_qwen2.py index d5abd41ae..166b31df9 100644 --- a/tests/test_shardformer/test_model/test_shard_qwen2.py +++ b/tests/test_shardformer/test_model/test_shard_qwen2.py @@ -217,6 +217,7 @@ def check_qwen2_3d(rank, world_size, port): @pytest.mark.skipif(transformers.__version__ < "4.39.1", reason="Requires transformers version 4.39.1 or later") +@pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_qwen2(): @@ -224,6 +225,7 @@ def test_qwen2(): @pytest.mark.skipif(transformers.__version__ < "4.39.1", reason="Requires transformers version 4.39.1 or later") +@pytest.mark.largedist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_qwen2_3d(): diff --git a/version.txt b/version.txt index 667843220..940ac09aa 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.3.8 +0.3.9