mirror of https://github.com/hpcaitech/ColossalAI
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into main
commit 4b59d874df
@@ -2,7 +2,7 @@ name: Build on PR
 on:
   pull_request:
-    types: [synchronize, opened, reopened, ready_for_review, closed, edited]
+    types: [synchronize, opened, reopened, ready_for_review, closed]
     branches:
       - "main"
       - "develop"

README.md
@@ -25,6 +25,7 @@
 </div>

 ## Latest News
+* [2024/06] [Open-Sora Continues Open Source: Generate Any 16-Second 720p HD Video with One Click, Model Weights Ready to Use](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
 * [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference)
 * [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
 * [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series)
@@ -32,10 +33,6 @@
 * [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
 * [2024/03] [Open-Sora: Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
 * [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer)
-* [2024/01] [Construct Refined 13B Private Model With Just $5000 USD, Upgraded Colossal-AI Llama-2 Open Source](https://hpc-ai.com/blog/colossal-llama-2-13b)
-* [2023/11] [Enhanced MoE Parallelism, Open-source MoE Model Training Can Be 9 Times More Efficient](https://www.hpc-ai.tech/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient)
-* [2023/09] [One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific LLM Solution](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution)
-* [2023/09] [70 Billion Parameter LLaMA2 Model Training Accelerated by 195%](https://www.hpc-ai.tech/blog/70b-llama2-training)
 * [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)

 ## Table of Contents
@@ -132,13 +129,13 @@ distributed training and inference in a few lines.

 [Open-Sora](https://github.com/hpcaitech/Open-Sora): Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models
 [[code]](https://github.com/hpcaitech/Open-Sora)
-[[blog]](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
-[[HuggingFace model weights]](https://huggingface.co/hpcai-tech/Open-Sora)
+[[blog]](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
+[[Model weights]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#model-weights)
 [[Demo]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo)

 <div align="center">
-  <a href="https://www.youtube.com/watch?v=iDTxepqixuc">
-    <img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/sora/sora-demo.png" width="700" />
+  <a href="https://youtu.be/ilMQpU71ddI?si=J4JSPzZ03ycYmlki">
+    <img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/sora/opensora-v1.2.png" width="700" />
   </a>
 </div>

@@ -999,7 +999,9 @@ class HybridParallelPlugin(PipelinePluginBase):
        ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"

        if enable_sequence_parallelism:
-            self.sequence_parallelism_mode = sequence_parallelism_mode if sequence_parallelism_mode is not None else "1"
+            self.sequence_parallelism_mode = (
+                sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
+            )
            assert (
                self.sequence_parallelism_mode in SUPPORT_SP_MODE
            ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
@@ -1014,19 +1016,13 @@ class HybridParallelPlugin(PipelinePluginBase):
                self.sp_size = 1
                self.dp_size = dist.get_world_size() // (tp_size * pp_size)
            elif self.sequence_parallelism_mode in ["all_to_all"]:
-                assert (
-                    tp_size == 1
-                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with tensor parallelism"
-                assert (
-                    pp_size == 1
-                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with pipeline parallelism"
-                self.sp_size = dist.get_world_size() if sp_size is None else sp_size
-                self.dp_size = dist.get_world_size() // (self.sp_size * pp_size)
+                self.sp_size = 1 if sp_size is None else sp_size
+                self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
        else:
            self.dp_size = dist.get_world_size() // (tp_size * pp_size)
            assert (
                sp_size == 1 or sp_size is None
-            ), f"sp_size can only be set to a >1 number when enable_sequence_parallelism is True"
+            ), f"You should not set sp_size when sequence parallelism is not enabled."
            self.sp_size = 1

        self.tp_size = tp_size
@@ -1040,11 +1036,22 @@ class HybridParallelPlugin(PipelinePluginBase):
        self.enable_jit_fused = enable_jit_fused
        self.enable_sequence_parallelism = enable_sequence_parallelism
        if dp_outside:
-            self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
+            (
+                self.dp_axis,
+                self.pp_axis,
+                self.tp_axis,
+                self.sp_axis,
+            ) = (
+                0,
+                1,
+                2,
+                3,
+            )
            self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
        else:
            self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
            self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)

        self.stage_manager = None
        self.schedule = None
        self.custom_policy = custom_policy

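For readers skimming the hunk above, the updated "all_to_all" branch simply requires the four parallel sizes to tile the world size. A minimal sketch of that arithmetic, using hypothetical sizes that are not taken from the commit:

```python
# Hypothetical launch configuration; only the arithmetic mirrors the hunk above.
world_size = 16
tp_size, pp_size, sp_size = 2, 2, 2

# "all_to_all" branch: sp_size defaults to 1 when unset, and dp_size absorbs the rest.
dp_size = world_size // (sp_size * pp_size * tp_size)
assert dp_size * sp_size * pp_size * tp_size == world_size, "world size must be divisible"
print(dp_size)  # -> 2 data-parallel groups
```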
@@ -314,7 +314,7 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
    """
    # Move all tensors in the state_dict to CPU before saving to avoid serialization issues
-    state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)
+    state_dict_cpu = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, state_dict)

    if use_safetensors:
        assert is_safetensors_available(), "safetensors is not available."

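The one-line change above swaps `x.cpu()` for `x.data.cpu()` inside a `tree_map`. A small self-contained sketch of the same pattern, assuming the `tree_map` in question is `torch.utils._pytree.tree_map`:

```python
import torch
from torch.utils._pytree import tree_map  # assumption: the tree_map used by save_state_dict

state_dict = {"weight": torch.randn(2, 2, requires_grad=True), "step": 3}

# Copy every tensor leaf to CPU via .data (bypassing autograd history); leave other leaves untouched.
state_dict_cpu = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, state_dict)

print({k: (v.device.type if torch.is_tensor(v) else v) for k, v in state_dict_cpu.items()})
# -> {'weight': 'cpu', 'step': 3}
```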
@@ -236,7 +236,7 @@ Completion api is used for single sequence request, like answer a question or co
 - POST '/chat':
     Chat api is used for conversation-style request, which often includes dialogue participants(i.e. roles) and corresponding words. Considering the input data are very different from normal inputs, we introduce Chat-Template to match the data format in chat models.
 #### chat-template
-Followed `transformers`, we add the chat-template argument. As chat models have been trained with very different formats for converting conversations into a single tokenizable string. Using a format that matches the training data is extremely important. This attribute(chat_template) is inclueded in HuggingFace tokenizers, containing a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace-blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example temlate bellow. Both str or file style chat template are supported.
+Following `transformers`, we add the chat-template argument. Since chat models have been trained with very different formats for converting conversations into a single tokenizable string, using a format that matches the training data is extremely important. This attribute (chat_template) is included in HuggingFace tokenizers, containing a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example template below. Both string- and file-style chat templates are supported.
 ### Usage
 #### Args for customizing your server
 The configuration for api server contains both serving interface and engine backend.

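The chat-template paragraph above leans on the Hugging Face `chat_template` mechanism. A minimal, illustrative string-style template can be exercised with `transformers` as follows; the template text and the use of the GPT-2 tokenizer are placeholders for illustration, not part of this commit, and a recent `transformers` with chat-template support is assumed:

```python
from transformers import AutoTokenizer

# Any tokenizer works here once a chat template is attached; gpt2 is used only because it is small and public.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# A deliberately simple Jinja chat template, assumed for illustration only.
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
)

messages = [
    {"role": "user", "content": "What does the /chat endpoint expect?"},
    {"role": "assistant", "content": "A list of role/content messages rendered by the chat template."},
]
print(tokenizer.apply_chat_template(messages, tokenize=False))
```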
@@ -10,6 +10,7 @@ import torch
 from transformers.generation import GenerationConfig

 from colossalai.inference.flash_decoding_utils import FDIntermTensors
+from colossalai.inference.utils import can_use_flash_attn2

 GibiByte = 1024**3

@@ -169,7 +170,8 @@ class InferenceConfig(RPC_PARAM):
        no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences.
        repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.
        ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
-        n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
+        use_spec_dec (bool): Indicate whether to use speculative decoding, defaults to False.
+        max_n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
        glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
        block_size (int): The number of blocks in a logical block, defaults to 16.
        tp_size (int): Tensor parallel size, defaults to 1.
@@ -214,6 +216,7 @@ class InferenceConfig(RPC_PARAM):
    ignore_eos: bool = False

    # speculative decoding configs
+    use_spec_dec: bool = False
    max_n_spec_tokens: int = 5
    glimpse_large_kv: bool = False

@@ -311,6 +314,16 @@ class InferenceConfig(RPC_PARAM):

        return GenerationConfig.from_dict(meta_config)

+    def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig":
+        use_flash_attn = can_use_flash_attn2(self.dtype)
+        model_inference_config = ModelShardInferenceConfig(
+            dtype=self.dtype,
+            use_cuda_kernel=self.use_cuda_kernel,
+            use_spec_dec=self.use_spec_dec,
+            use_flash_attn=use_flash_attn,
+        )
+        return model_inference_config
+
    def to_rpc_param(self) -> dict:
        kwargs = {
            "dtype": str(self.dtype).split(".")[-1],
@@ -362,3 +375,21 @@ class InferenceConfig(RPC_PARAM):
        # Set the attributes from the parsed arguments.
        inference_config = cls(**inference_config_args)
        return inference_config
+
+
+@dataclass
+class ModelShardInferenceConfig:
+    """
+    Configurations used during init of module for inference modeling.
+
+    Args:
+        dtype (torch.dtype): The data type for weights and activations.
+        use_cuda_kernel (bool): Whether to use cuda kernel, faster but lose some precision occasionally
+        use_spec_dec (bool): Indicate whether to use speculative decoding.
+        use_flash_attn (bool): Indicate whether to use flash attention.
+    """
+
+    dtype: torch.dtype = None
+    use_cuda_kernel: bool = False
+    use_spec_dec: bool = False
+    use_flash_attn: bool = False

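Taken together, the hunks above let callers derive a `ModelShardInferenceConfig` directly from an `InferenceConfig`. A hedged usage sketch; the class and method names come from the diff, while the assumption that these keyword fields can be passed to the constructor with defaults for everything else is ours:

```python
from colossalai.inference.config import InferenceConfig, ModelShardInferenceConfig

# Assumed-valid keyword fields, per the dataclass attributes shown in the hunks above.
inference_config = InferenceConfig(use_cuda_kernel=True, use_spec_dec=False)

shard_config = inference_config.to_model_shard_inference_config()
assert isinstance(shard_config, ModelShardInferenceConfig)
# use_flash_attn is True only when flash-attn is installed and dtype is fp16/bf16 (see can_use_flash_attn2).
print(shard_config.dtype, shard_config.use_cuda_kernel, shard_config.use_flash_attn)
```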
@@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
 from colossalai.accelerator import get_accelerator
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.inference.batch_bucket import BatchBucket
-from colossalai.inference.config import InferenceConfig, InputMetaData
+from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
 from colossalai.inference.graph_runner import CUDAGraphRunner
 from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.sampler import search_tokens
@@ -72,8 +72,9 @@ class InferenceEngine:

        self.verbose = verbose
        self.logger = get_dist_logger(__name__)
+        self.model_shard_infer_config = inference_config.to_model_shard_inference_config()

-        self.init_model(model_or_path, model_policy)
+        self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

        self.generation_config = inference_config.to_generation_config(self.model_config)
        self.generation_config_dict = self.generation_config.to_dict()
@@ -97,7 +98,8 @@ class InferenceEngine:
            self.capture_model(self.k_cache, self.v_cache)

        # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
-        self.use_spec_dec = False
+        self.use_spec_dec = self.inference_config.use_spec_dec
+
        self.drafter_model = None
        self.drafter = None
        self.use_glide = False
@@ -105,13 +107,20 @@ class InferenceEngine:

        self._verify_args()

-    def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
+    def init_model(
+        self,
+        model_or_path: Union[nn.Module, str],
+        model_policy: Union[Policy, Type[Policy]] = None,
+        model_shard_infer_config: ModelShardInferenceConfig = None,
+    ):
        """
        Shard model or/and Load weight

        Args:
            model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
-            model_policy (Policy): the policy to replace the model
+            model_policy (Policy): the policy to replace the model.
+            model_inference_config: the configuration for modeling initialization when inference.
+            model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
        """

        if isinstance(model_or_path, str):
@@ -124,6 +133,7 @@ class InferenceEngine:
                # the model load process in the future.
                model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
            else:
+                # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
                raise ValueError(f"Model {arch} is not supported.")

        except Exception as e:
@@ -167,6 +177,7 @@ class InferenceEngine:
            self.model = self._shardformer(
                model,
                model_policy,
+                model_shard_infer_config,
                None,
                tp_group=tp_group,
            )
@@ -187,7 +198,7 @@ class InferenceEngine:
        # assert if_has_index_file, "the model path is invalid"
        # cpt_io.load_model(self.model, model_index_file)

-        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
+        free_gpu_memory, _ = torch.cuda.mem_get_info()
        peak_memory = init_gpu_memory - free_gpu_memory
        if self.verbose:
            self.logger.info(
@@ -287,6 +298,7 @@ class InferenceEngine:
        self,
        model: nn.Module,
        model_policy: Policy,
+        model_shard_infer_config: ModelShardInferenceConfig = None,
        stage_manager: PipelineStageManager = None,
        tp_group: ProcessGroupMesh = None,
    ) -> nn.Module:
@@ -312,6 +324,7 @@ class InferenceEngine:
            enable_flash_attention=False,
            enable_jit_fused=False,
            enable_sequence_parallelism=False,
+            extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
        )
        shardformer = ShardFormer(shard_config=shardconfig)
        shard_model, _ = shardformer.optimize(model, model_policy)
@@ -348,6 +361,7 @@ class InferenceEngine:
            engine.clear_spec_dec()
        ```
        """
+
        if drafter_model is None and self.drafter is None:
            raise ValueError("Drafter not initialized. Please provide a Drafter Model")
        if n_spec_tokens is not None:
@@ -452,6 +466,7 @@ class InferenceEngine:
                self.k_cache[-1],  # use kv cahces of the last layer
                self.v_cache[-1],
                batch.get_sequence_lengths(),
+                n_spec_tokens=self.n_spec_tokens,
            )

            drafter_out = self.drafter.speculate(
@@ -517,19 +532,19 @@ class InferenceEngine:
        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
        return_token_ids: bool = False,
        generation_config: Optional[GenerationConfig] = None,
-    ) -> List[str]:
+    ) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
        """
        Executing the inference step.

        Args:
-            prompts (Union[List[str], optional): Input prompts. Defaults to None.
-            prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
            request_ids (List[int], optional): The request ID. Defaults to None.
-            return_token_ids (bool): Whether to return output token ids. Defaults to False.
-            generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
+            prompts (Union[List[str], optional): Input prompts. Defaults to None.
+            prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
+            return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
+            generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.

        Returns:
-            List[str]: Inference result returned by one generation.
+            Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
        """

        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}

@@ -0,0 +1,170 @@ (new file; all lines added)
from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import context_attention_unpadded, flash_decoding_attention


@dataclass
class AttentionMetaData:
    query_states: torch.Tensor
    key_states: torch.Tensor
    value_states: torch.Tensor
    k_cache: torch.Tensor
    v_cache: torch.Tensor
    block_tables: torch.Tensor
    block_size: int
    kv_seq_len: int = None
    sequence_lengths: torch.Tensor = None
    cu_seqlens: torch.Tensor = None
    sm_scale: int = None
    alibi_slopes: torch.Tensor = None
    output_tensor: torch.Tensor = None
    use_spec_dec: bool = False
    use_alibi_attn: bool = False


class AttentionBackend(ABC):
    @abstractmethod
    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def decode(self, attn_metadatas: AttentionMetaData, **kwargs):
        raise NotImplementedError


class CudaAttentionBackend(AttentionBackend):
    """
    Attention backend used when use_cuda_kernel is True. If flash-attn is not found,
    it uses the Triton op `context_attention_unpadded` for prefilling and our CUDA op `flash_decoding_attention` for decoding.
    """

    def __init__(self, use_flash_attn: bool = False):
        super().__init__()
        self.inference_ops = InferenceOpsLoader().load()
        self.use_flash_attn = use_flash_attn

    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        if self.use_flash_attn:
            token_nums = kwargs.get("token_nums", -1)

            from flash_attn import flash_attn_varlen_func

            attn_output = flash_attn_varlen_func(
                attn_metadata.query_states,
                attn_metadata.key_states,
                attn_metadata.value_states,
                cu_seqlens_q=attn_metadata.cu_seqlens,
                cu_seqlens_k=attn_metadata.cu_seqlens,
                max_seqlen_q=attn_metadata.kv_seq_len,
                max_seqlen_k=attn_metadata.kv_seq_len,
                dropout_p=0.0,
                softmax_scale=attn_metadata.sm_scale,
                causal=True,
                alibi_slopes=attn_metadata.alibi_slopes,
            )
            attn_output = attn_output.view(token_nums, -1)
        else:
            attn_output = context_attention_unpadded(
                q=attn_metadata.query_states,
                k=attn_metadata.key_states,
                v=attn_metadata.value_states,
                k_cache=attn_metadata.k_cache,
                v_cache=attn_metadata.v_cache,
                context_lengths=attn_metadata.sequence_lengths,
                block_tables=attn_metadata.block_tables,
                block_size=attn_metadata.block_size,
                output=attn_metadata.output_tensor,
                alibi_slopes=attn_metadata.alibi_slopes,
                max_seq_len=attn_metadata.kv_seq_len,
                sm_scale=attn_metadata.sm_scale,
                use_new_kcache_layout=True,  # use new k-cache layout
            )
        return attn_output

    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
        output_tensor = attn_metadata.output_tensor
        self.inference_ops.flash_decoding_attention(
            output_tensor,
            attn_metadata.query_states,
            attn_metadata.k_cache,
            attn_metadata.v_cache,
            attn_metadata.sequence_lengths,
            attn_metadata.block_tables,
            attn_metadata.block_size,
            attn_metadata.kv_seq_len,
            fd_inter_tensor.mid_output,
            fd_inter_tensor.exp_sums,
            fd_inter_tensor.max_logits,
            attn_metadata.alibi_slopes,
            attn_metadata.sm_scale,
        )
        return output_tensor


class TritonAttentionBackend(AttentionBackend):
    """
    Attention backend when use_cuda_kernel is False. It uses pure Triton ops for prefilling and decoding.
    """

    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        return context_attention_unpadded(
            q=attn_metadata.query_states,
            k=attn_metadata.key_states,
            v=attn_metadata.value_states,
            k_cache=attn_metadata.k_cache,
            v_cache=attn_metadata.v_cache,
            context_lengths=attn_metadata.sequence_lengths,
            block_tables=attn_metadata.block_tables,
            block_size=attn_metadata.block_size,
            output=attn_metadata.output_tensor,
            alibi_slopes=attn_metadata.alibi_slopes,
            max_seq_len=attn_metadata.kv_seq_len,
            sm_scale=attn_metadata.sm_scale,
        )

    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
        return flash_decoding_attention(
            q=attn_metadata.query_states,
            k_cache=attn_metadata.k_cache,
            v_cache=attn_metadata.v_cache,
            kv_seq_len=attn_metadata.sequence_lengths,
            block_tables=attn_metadata.block_tables,
            block_size=attn_metadata.block_size,
            max_seq_len_in_batch=attn_metadata.kv_seq_len,
            output=attn_metadata.output_tensor,
            mid_output=fd_inter_tensor.mid_output,
            mid_output_lse=fd_inter_tensor.mid_output_lse,
            alibi_slopes=attn_metadata.alibi_slopes,
            sm_scale=attn_metadata.sm_scale,
            kv_group_num=kwargs.get("num_key_value_groups", 1),
            q_len=kwargs.get("q_len", 1),
        )


def get_attention_backend(
    model_shard_infer_config: ModelShardInferenceConfig,
) -> AttentionBackend:
    """
    Get the attention backend based on the inference configuration. The modeling layer will use the
    CUDA-kernel-based backend for attention computation only when:
        1. CUDA kernels are enabled (use_cuda_kernel=True)
        2. flash attention is usable (flash-attn installed and dtype is fp16 or bf16)
        3. speculative decoding is not used (currently the CUDA kernels do not support speculative decoding)
    Otherwise, the Triton attention backend is used. If flash-attn is not installed while `use_cuda_kernel`
    is True, prefilling falls back to the Triton kernel with the new k-cache layout.
    """
    # Currently only triton kernels support speculative decoding
    if model_shard_infer_config.use_spec_dec:
        return TritonAttentionBackend()

    if model_shard_infer_config.use_cuda_kernel:
        return CudaAttentionBackend(model_shard_infer_config.use_flash_attn)

    return TritonAttentionBackend()

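As a quick orientation for the new file above, the selector can be exercised on its own. The sketch below uses only names defined in this commit and assumes the ColossalAI CUDA/Triton kernel extensions are built and importable on the machine running it:

```python
import torch

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.modeling.backends.attention_backend import get_attention_backend

cfg = ModelShardInferenceConfig(
    dtype=torch.float16,
    use_cuda_kernel=True,
    use_spec_dec=False,   # speculative decoding forces the Triton backend
    use_flash_attn=True,  # True only if flash-attn is actually installed
)
backend = get_attention_backend(cfg)
print(type(backend).__name__)  # -> CudaAttentionBackend, per the selection rules above
```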
@@ -0,0 +1,146 @@ (new file; all lines added)
from abc import ABC, abstractmethod

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_k_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding


class PreAttentionBackend(ABC):
    @abstractmethod
    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        raise NotImplementedError


class CudaPreAttentionBackend(PreAttentionBackend):
    """
    CudaPreAttentionBackend handles KV cache initialization and positional encoding for CudaAttentionBackend.
    """

    def __init__(self, use_flash_attn: bool):
        super().__init__()
        self.inference_ops = InferenceOpsLoader().load()
        self.use_flash_attn = use_flash_attn

    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        if self.use_flash_attn:
            if not attn_metadata.use_alibi_attn:
                self.inference_ops.rotary_embedding(
                    attn_metadata.query_states,
                    attn_metadata.key_states,
                    kwargs.get("cos", None),
                    kwargs.get("sin", None),
                    kwargs.get("high_precision", False),
                )
            self.inference_ops.context_kv_cache_memcpy(
                attn_metadata.key_states,
                attn_metadata.value_states,
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.sequence_lengths,
                attn_metadata.cu_seqlens,
                attn_metadata.block_tables,
                attn_metadata.kv_seq_len,
            )
        elif not attn_metadata.use_alibi_attn:
            rotary_embedding(
                attn_metadata.query_states,
                attn_metadata.key_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
            )

    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        if not attn_metadata.use_alibi_attn:
            self.inference_ops.rotary_embedding_and_cache_copy(
                attn_metadata.query_states,
                attn_metadata.key_states,
                attn_metadata.value_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.sequence_lengths,
                attn_metadata.block_tables,
                kwargs.get("high_precision", None),
            )
        else:
            self.inference_ops.decode_kv_cache_memcpy(
                attn_metadata.key_states,
                attn_metadata.value_states,
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.sequence_lengths,
                attn_metadata.block_tables,
            )


class TritonPreAttentionBackend(PreAttentionBackend):
    """
    TritonPreAttentionBackend handles KV cache initialization and positional encoding for TritonAttentionBackend.
    """

    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        if not attn_metadata.use_alibi_attn:
            rotary_embedding(
                attn_metadata.query_states,
                attn_metadata.key_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
            )

    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        if not attn_metadata.use_spec_dec and not attn_metadata.use_alibi_attn:
            decoding_fused_rotary_embedding(
                attn_metadata.query_states,
                attn_metadata.key_states,
                attn_metadata.value_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.block_tables,
                attn_metadata.sequence_lengths,
            )
        else:  # else if using speculative decoding
            if not attn_metadata.use_alibi_attn:
                rotary_embedding(
                    attn_metadata.query_states,
                    attn_metadata.key_states,
                    kwargs.get("cos", None),
                    kwargs.get("sin", None),
                )
            copy_k_to_blocked_cache(
                attn_metadata.key_states,
                attn_metadata.k_cache,
                kv_lengths=attn_metadata.sequence_lengths,
                block_tables=attn_metadata.block_tables,
                n=kwargs.get("q_len", 1),
            )
            copy_k_to_blocked_cache(
                attn_metadata.value_states,
                attn_metadata.v_cache,
                kv_lengths=attn_metadata.sequence_lengths,
                block_tables=attn_metadata.block_tables,
                n=kwargs.get("q_len", 1),
            )


def get_pre_attention_backend(
    model_shard_infer_config: ModelShardInferenceConfig,
) -> PreAttentionBackend:
    """
    Get the backend for pre-attention computations, including positional encoding like
    RoPE and KV cache initialization. It adopts the same selection logic as attention_backend/get_attention_backend.
    """
    if model_shard_infer_config.use_spec_dec:
        return TritonPreAttentionBackend()

    if model_shard_infer_config.use_cuda_kernel:
        return CudaPreAttentionBackend(model_shard_infer_config.use_flash_attn)

    return TritonPreAttentionBackend()

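Because the two selectors share the same decision rules, a modeling layer typically pairs them, as `NopadBaichuanAttention` does later in this commit. A minimal sketch of that pairing; the helper name `build_backends` is ours, not part of the commit:

```python
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.modeling.backends.attention_backend import get_attention_backend
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend


def build_backends(model_shard_infer_config: ModelShardInferenceConfig):
    # Both selectors branch on use_spec_dec / use_cuda_kernel / use_flash_attn in the same way,
    # so the attention step and its pre-attention step always land on matching kernel families.
    attention_backend = get_attention_backend(model_shard_infer_config)
    pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)
    return attention_backend, pre_attention_backend
```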
@@ -15,25 +15,10 @@ class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
        module.in_features = module.weight.size(1)
        module.out_features = module.weight.size(0)
        module.bias = None
-        module.weight.data = nn.functional.normalize(module.weight)
-        return Linear1D_Col.from_native_module(
-            module,
-            process_group,
-            *args,
-            **kwargs,
-        )
-
-
-class BaichuanWpackLinear1D_Col(Linear1D_Col):
-    @staticmethod
-    def from_native_module(
-        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
-    ) -> ParallelModule:
-        in_features = module.in_features * 3
-        out_features = module.out_features // 3
-        module.weight.data = module.weight.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features)
-        module.bias = None
-
+        module.weight.data = nn.functional.normalize(
+            module.weight
+        )  # TODO(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
+        # So we should rewrite our own load_from_state_dict of `BaichuanLMHeadLinear1D_Col` to fix this potential issue.
+
        return Linear1D_Col.from_native_module(
            module,

@@ -6,11 +6,7 @@ from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
-from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import (
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
@@ -137,6 +133,7 @@ def glide_llama_model_forward(
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
@@ -147,57 +144,43 @@ def glide_llama_model_forward(
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
-    if input_ids is not None and inputs_embeds is not None:
-        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-    elif input_ids is not None:
-        batch_size, seq_length = input_ids.shape[:2]
-    elif inputs_embeds is not None:
-        batch_size, seq_length = inputs_embeds.shape[:2]
-    else:
-        raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-    past_key_values_length = 0
-    if use_cache:
-        use_legacy_cache = not isinstance(past_key_values, Cache)
-        if use_legacy_cache:
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-        past_key_values_length = past_key_values.get_usable_length(seq_length)
-
-    if position_ids is None:
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-        position_ids = torch.arange(
-            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-        )
-        position_ids = position_ids.unsqueeze(0)
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError(
+            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+        )
+
+    if self.gradient_checkpointing and self.training and use_cache:
+        logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
+        use_cache = False

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

-    if self._use_flash_attention_2:
-        # 2d mask is passed through the layers
-        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-    elif self._use_sdpa and not output_attentions:
-        # output_attentions=True can not be supported when using SDPA, and we fall back on
-        # the manual implementation that requires a 4D causal mask in all cases.
-        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-            attention_mask,
-            (batch_size, seq_length),
-            inputs_embeds,
-            past_key_values_length,
-        )
-    else:
-        # 4d mask is passed through the layers
-        attention_mask = _prepare_4d_causal_attention_mask(
-            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-        )
+    past_seen_tokens = 0
+    if use_cache:  # kept for BC (cache positions)
+        if not isinstance(past_key_values, StaticCache):
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            past_seen_tokens = past_key_values.get_seq_length()
+
+    if cache_position is None:
+        if isinstance(past_key_values, StaticCache):
+            raise ValueError("cache_position is a required argument when using StaticCache.")
+        cache_position = torch.arange(
+            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
+
+    if position_ids is None:
+        position_ids = cache_position.unsqueeze(0)
+
+    attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
-    next_decoder_cache = () if use_cache else None
+    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
@@ -212,6 +195,7 @@ def glide_llama_model_forward(
            past_key_value=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
+            cache_position=cache_position,
        )

        hidden_states = layer_outputs[0]
@@ -230,7 +214,9 @@ def glide_llama_model_forward(

    next_cache = None
    if use_cache:
-        next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = (
+            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
+        )
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
@@ -333,7 +319,8 @@ class LlamaCrossAttention(nn.Module):
        query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)

        # for RoPE
-        cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len + 32)
+        position_ids = position_ids + glide_input.n_spec_tokens
+        cos, sin = self.rotary_emb(query_states, position_ids)
        query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
        query_states = query_states.transpose(1, 2)
        query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)

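The rewritten forward above replaces the old `past_key_values_length` bookkeeping with `cache_position`. The arithmetic itself is tiny; an illustrative stand-alone check with hypothetical values:

```python
import torch

past_seen_tokens, new_tokens = 10, 4  # hypothetical cache state and incoming chunk length

# Mirrors the updated forward: positions for the new tokens start right after the cached ones.
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
position_ids = cache_position.unsqueeze(0)  # used when position_ids is None

print(cache_position.tolist())  # [10, 11, 12, 13]
print(position_ids.shape)       # torch.Size([1, 4])
```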
@ -1,68 +1,27 @@
|
||||||
# This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
|
# This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
|
||||||
import itertools
|
|
||||||
import math
|
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
from colossalai.accelerator import get_accelerator
|
||||||
|
from colossalai.inference.config import ModelShardInferenceConfig
|
||||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||||
|
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
|
||||||
|
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
|
||||||
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
|
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
|
||||||
|
from colossalai.inference.utils import get_alibi_slopes
|
||||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||||
from colossalai.kernel.triton import (
|
from colossalai.kernel.triton import rms_layernorm
|
||||||
context_attention_unpadded,
|
|
||||||
copy_k_to_blocked_cache,
|
|
||||||
decoding_fused_rotary_embedding,
|
|
||||||
flash_decoding_attention,
|
|
||||||
rms_layernorm,
|
|
||||||
rotary_embedding,
|
|
||||||
)
|
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.shardformer.layer.parallel_module import ParallelModule
|
from colossalai.shardformer.layer.parallel_module import ParallelModule
|
||||||
from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor
|
from colossalai.tensor.d_tensor import is_distributed_tensor
|
||||||
|
|
||||||
logger = get_dist_logger(__name__)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from flash_attn import flash_attn_varlen_func
|
|
||||||
|
|
||||||
use_flash_attn2 = True
|
|
||||||
except ImportError:
|
|
||||||
use_flash_attn2 = False
|
|
||||||
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
|
|
||||||
|
|
||||||
logger = get_dist_logger(__name__)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from flash_attn import flash_attn_varlen_func
|
|
||||||
|
|
||||||
use_flash_attn2 = True
|
|
||||||
except ImportError:
|
|
||||||
use_flash_attn2 = False
|
|
||||||
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
|
|
||||||
|
|
||||||
inference_ops = InferenceOpsLoader().load()
|
inference_ops = InferenceOpsLoader().load()
|
||||||
|
|
||||||
logger = get_dist_logger(__name__)
|
logger = get_dist_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
|
|
||||||
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
|
|
||||||
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
|
|
||||||
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
|
|
||||||
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
|
|
||||||
slopes = torch.pow(base, powers)
|
|
||||||
if closest_power_of_2 != num_heads:
|
|
||||||
extra_base = torch.tensor(
|
|
||||||
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
|
|
||||||
)
|
|
||||||
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
|
|
||||||
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
|
|
||||||
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
|
||||||
return slopes
|
|
||||||
|
|
||||||
|
|
||||||
def baichuan_rmsnorm_forward(
|
def baichuan_rmsnorm_forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -96,23 +55,19 @@ class NopadBaichuanAttention(ParallelModule):
     def __init__(
         self,
         config,
-        attn_qproj_w: torch.Tensor = None,
+        W_pack: ParallelModule = None,
-        attn_kproj_w: torch.Tensor = None,
-        attn_vproj_w: torch.Tensor = None,
         attn_oproj: ParallelModule = None,
         num_heads: int = None,
         hidden_size: int = None,
+        model_shard_infer_config: ModelShardInferenceConfig = None,
         process_group: ProcessGroup = None,
-        helper_layout: Layout = None,
     ):
         """This layer will replace the BaichuanAttention.

         Args:
             config (BaichuanConfig): Holding the Baichuan model config.
-            attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
+            W_pack (ParallelModule, optional): The packed weight. Defaults to None.
-            attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
+            attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj. Defaults to None.
-            attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
-            attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj weight. Defaults to None.
         """
         ParallelModule.__init__(self)
         self.o_proj = attn_oproj
@@ -122,10 +77,10 @@ class NopadBaichuanAttention(ParallelModule):
         self.hidden_size = hidden_size
         self.head_dim = self.hidden_size // self.num_heads
         self.process_group = process_group
-        qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
+        self.W_pack = W_pack
-        self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
+        self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
+        self.attention_backend = get_attention_backend(model_shard_infer_config)
-        self.helper_layout = helper_layout
+        self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)

         self.alibi_slopes = None
         self.use_alibi_attn = False
@@ -133,9 +88,9 @@ class NopadBaichuanAttention(ParallelModule):
         if config.hidden_size == 5120:
             slopes_start = self.process_group.rank() * num_heads
             self.use_alibi_attn = True
-            self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
+            self.alibi_slopes = get_alibi_slopes(
-                slopes_start : slopes_start + num_heads
+                config.num_attention_heads, device=get_accelerator().get_current_device()
-            ].contiguous()
+            )[slopes_start : slopes_start + num_heads].contiguous()
             self.alibi_slopes = nn.Parameter(self.alibi_slopes)

     @staticmethod
@@ -149,76 +104,22 @@ class NopadBaichuanAttention(ParallelModule):
         """

         config = module.config
-        q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((module.hidden_size, 3, -1)).transpose(0, 1)
+        W_pack = module.W_pack

-        attn_qproj_w = q_proj_w
-        attn_kproj_w = k_proj_w
-        attn_vproj_w = v_proj_w
         attn_oproj = module.o_proj
+        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
-        helper_layout = (
-            module.W_pack.weight.dist_layout
-        )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)

         attn_layer = NopadBaichuanAttention(
             config=config,
-            attn_qproj_w=attn_qproj_w,
+            W_pack=W_pack,
-            attn_kproj_w=attn_kproj_w,
-            attn_vproj_w=attn_vproj_w,
             attn_oproj=attn_oproj,
+            model_shard_infer_config=model_shard_infer_config,
             num_heads=module.num_heads,
             hidden_size=module.hidden_size,
             process_group=process_group,
-            helper_layout=helper_layout,
         )

         return attn_layer

-    def _load_from_state_dict(
-        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
-    ):
-        for hook in self._load_state_dict_pre_hooks.values():
-            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

-        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-        local_state = {k: v for k, v in local_name_params if v is not None}

-        key = "qkv_weight"
-        qkv_w = state_dict[prefix + "W_pack.weight"]

-        in_features = qkv_w.size(1)
-        out_features = qkv_w.size(0) // 3

-        qkv_w.data = qkv_w.view((3, out_features, -1)).transpose(0, 1).reshape(out_features, in_features * 3)

-        device_mesh = self.helper_layout.device_mesh
-        sharding_spec = self.helper_layout.sharding_spec
-        qkv_w = distribute_tensor(qkv_w, device_mesh, sharding_spec)

-        qkv_w = qkv_w.transpose(0, 1).reshape(3, in_features, -1)
-        input_param = nn.Parameter(
-            qkv_w
-        )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)

-        param = local_state[key]

-        try:
-            with torch.no_grad():
-                param.copy_(input_param)
-        except Exception as ex:
-            error_msgs.append(
-                'While copying the parameter named "{}", '
-                "whose dimensions in the model are {} and "
-                "whose dimensions in the checkpoint are {}, "
-                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-            )

-        strict = False  # to avoid unexpected_keys
-        super()._load_from_state_dict(
-            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
-        )

     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -234,7 +135,6 @@ class NopadBaichuanAttention(ParallelModule):
         kv_seq_len: int = 0,
         output_tensor: torch.Tensor = None,
         sm_scale: int = None,
-        use_cuda_kernel: bool = True,
         cu_seqlens: torch.Tensor = None,
         high_precision: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -253,133 +153,58 @@ class NopadBaichuanAttention(ParallelModule):
             kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
             output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
             sm_scale (int, optional): Used for flash attention. Defaults to None.
-            use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
             cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
             high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
         """

         token_nums = hidden_states.size(0)
-        # fused qkv
-        hidden_states = hidden_states.expand(3, -1, -1)
+        proj = self.W_pack(hidden_states)
-        query_states, key_states, value_states = (
+        proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
-            torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
+        query_states = proj[0].view(token_nums, self.num_heads, self.head_dim)
-        )
+        key_states = proj[1].view(token_nums, self.num_heads, self.head_dim)
+        value_states = proj[2].view(token_nums, self.num_heads, self.head_dim)

         block_size = k_cache.size(-2)

-        if is_prompts:
+        attn_metadata = AttentionMetaData(
-            if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
+            query_states=query_states,
-                # flash attn 2 currently only supports FP16/BF16.
+            key_states=key_states,
-                if not self.use_alibi_attn:
+            value_states=value_states,
-                    inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
-                inference_ops.context_kv_cache_memcpy(
-                    key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
-                )
-                attn_output = flash_attn_varlen_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    cu_seqlens_q=cu_seqlens,
-                    cu_seqlens_k=cu_seqlens,
-                    max_seqlen_q=kv_seq_len,
-                    max_seqlen_k=kv_seq_len,
-                    dropout_p=0.0,
-                    softmax_scale=sm_scale,
-                    causal=True,
-                    alibi_slopes=self.alibi_slopes,
-                )
-                attn_output = attn_output.view(token_nums, -1)
-            else:
-                if not self.use_alibi_attn:
-                    rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-                attn_output = context_attention_unpadded(
-                    q=query_states,
-                    k=key_states,
-                    v=value_states,
             k_cache=k_cache,
             v_cache=v_cache,
-                    context_lengths=sequence_lengths,
             block_tables=block_tables,
             block_size=block_size,
-                    output=output_tensor,
+            kv_seq_len=kv_seq_len,
-                    alibi_slopes=self.alibi_slopes,
+            sequence_lengths=sequence_lengths,
-                    max_seq_len=kv_seq_len,
             sm_scale=sm_scale,
-                    use_new_kcache_layout=use_cuda_kernel,
+            alibi_slopes=self.alibi_slopes,
+            cu_seqlens=cu_seqlens,
+            output_tensor=output_tensor,
+            use_spec_dec=is_verifier,
+            use_alibi_attn=self.use_alibi_attn,
-                )
+        )
-        else:
+        if is_prompts:  # prefilling stage
+            self.pre_attention_backend.prefill(
+                attn_metadata,
+                cos=cos_sin[0],
+                sin=cos_sin[1],
+                high_precision=high_precision,
+            )
+            attn_output = self.attention_backend.prefill(
+                attn_metadata,
+                token_nums=token_nums,
+            )
+        else:  # decoding stage
             q_len = tokens_to_verify + 1 if is_verifier else 1

-            if use_cuda_kernel:
+            self.pre_attention_backend.decode(
-                if not self.use_alibi_attn:
+                attn_metadata,
-                    inference_ops.rotary_embedding_and_cache_copy(
+                q_len=q_len,
-                        query_states,
-                        key_states,
-                        value_states,
-                        cos_sin[0],
-                        cos_sin[1],
-                        k_cache,
-                        v_cache,
-                        sequence_lengths,
-                        block_tables,
-                        high_precision,
             )
-                else:
+            attn_output = self.attention_backend.decode(
-                    inference_ops.decode_kv_cache_memcpy(
+                attn_metadata,
-                        key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
+                fd_inter_tensor=fd_inter_tensor,
-                    )
-                inference_ops.flash_decoding_attention(
-                    output_tensor,
-                    query_states,
-                    k_cache,
-                    v_cache,
-                    sequence_lengths,
-                    block_tables,
-                    block_size,
-                    kv_seq_len,
-                    fd_inter_tensor.mid_output,
-                    fd_inter_tensor.exp_sums,
-                    fd_inter_tensor.max_logits,
-                    self.alibi_slopes,
-                    sm_scale,
-                )
-                attn_output = output_tensor
-            else:
-                if not is_verifier and not self.use_alibi_attn:
-                    decoding_fused_rotary_embedding(
-                        query_states,
-                        key_states,
-                        value_states,
-                        cos_sin[0],
-                        cos_sin[1],
-                        k_cache,
-                        v_cache,
-                        block_tables,
-                        sequence_lengths,
-                    )
-                else:
-                    if not self.use_alibi_attn:
-                        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-                    copy_k_to_blocked_cache(
-                        key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
-                    )
-                    copy_k_to_blocked_cache(
-                        value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
-                    )

-                attn_output = flash_decoding_attention(
-                    q=query_states,
-                    k_cache=k_cache,
-                    v_cache=v_cache,
-                    kv_seq_len=sequence_lengths,
-                    block_tables=block_tables,
-                    block_size=block_size,
-                    max_seq_len_in_batch=kv_seq_len,
-                    output=output_tensor,
-                    mid_output=fd_inter_tensor.mid_output,
-                    mid_output_lse=fd_inter_tensor.mid_output_lse,
-                    alibi_slopes=self.alibi_slopes,
-                    sm_scale=sm_scale,
                 q_len=q_len,
             )

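For readers tracing the new fused-QKV path above: the `unflatten(...).unsqueeze(0).transpose(0, -2).squeeze(-2)` chain is a view-level split of the packed W_pack projection into its query/key/value thirds. A minimal sketch with made-up sizes (not tied to any ColossalAI API) that checks the split against plain slicing:

```python
import torch

tokens, hidden = 4, 8
proj = torch.randn(tokens, 3 * hidden)  # output of a packed qkv linear, shape [tokens, 3 * hidden]
split = proj.unflatten(-1, (3, hidden)).unsqueeze(0).transpose(0, -2).squeeze(-2)  # shape [3, tokens, hidden]

assert torch.equal(split[0], proj[:, :hidden])              # query slice
assert torch.equal(split[1], proj[:, hidden : 2 * hidden])  # key slice
assert torch.equal(split[2], proj[:, 2 * hidden :])         # value slice
```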
@@ -388,9 +213,6 @@ class NopadBaichuanAttention(ParallelModule):

         return attn_output

-    def extra_repr(self) -> str:
-        return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"


 # NOTE This will cause difference as out length increases.
 class NopadBaichuanMLP(NopadLlamaMLP):

@@ -16,18 +16,13 @@ from transformers.models.llama.modeling_llama import (
     LlamaRMSNorm,
 )

-from colossalai.inference.config import InputMetaData
+from colossalai.inference.config import InputMetaData, ModelShardInferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
+from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
+from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
+from colossalai.inference.utils import can_use_flash_attn2
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
-from colossalai.kernel.triton import (
+from colossalai.kernel.triton import get_xine_cache, rms_layernorm
-    context_attention_unpadded,
-    copy_k_to_blocked_cache,
-    decoding_fused_rotary_embedding,
-    flash_decoding_attention,
-    get_xine_cache,
-    rms_layernorm,
-    rotary_embedding,
-)
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer.layer.parallel_module import ParallelModule
 from colossalai.tensor.d_tensor import distribute_tensor, is_distributed_tensor
@@ -36,14 +31,6 @@ inference_ops = InferenceOpsLoader().load()

 logger = get_dist_logger(__name__)

-try:
-    from flash_attn import flash_attn_varlen_func

-    use_flash_attn2 = True
-except ImportError:
-    use_flash_attn2 = False
-    logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")


 def llama_causal_lm_forward(
     self: LlamaForCausalLM,
@@ -126,8 +113,8 @@ def llama_model_forward(
         cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])

     elif use_cuda_kernel:
-        if inputmetadata.dtype != torch.float32 and use_flash_attn2:
+        if can_use_flash_attn2(inputmetadata.dtype):
-            cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
+            cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))

         hidden_dim = self._cos_cached.size(-1)
         total_length = hidden_states.size(0)
@@ -238,7 +225,6 @@ def llama_decoder_layer_forward(
         kv_seq_len=kv_seq_len,
         output_tensor=output_tensor,
         sm_scale=sm_scale,
-        use_cuda_kernel=use_cuda_kernel,
         cu_seqlens=cu_seqlens,
         high_precision=high_precision,
     )
@@ -279,7 +265,7 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
         mlp_dproj: ParallelModule = None,
         process_group: ProcessGroup = None,
     ):
-        """A Unified Layer for
+        """Replacement of LlamaMLP layer.

         Args:
             config (LlamaConfig): Holding the Llama model config.
@@ -402,6 +388,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
         attn_vproj_w: torch.Tensor = None,
         attn_oproj: ParallelModule = None,
         process_group: ProcessGroup = None,
+        model_shard_infer_config: ModelShardInferenceConfig = None,
         num_heads: int = None,
         hidden_size: int = None,
         num_key_value_heads: int = None,
@@ -433,6 +420,9 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
         self.rope_theta = config.rope_theta
         self.is_causal = True

+        self.attention_backend = get_attention_backend(model_shard_infer_config)
+        self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)

         if self.num_heads == self.num_key_value_heads:
             qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
             self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
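The two backend handles registered above replace the long inline branching on `use_cuda_kernel` and flash-attn availability that the old forward carried. A hedged, illustrative sketch of the dispatch pattern only (class and function names here are invented for illustration; they are not the actual ColossalAI backend classes):

```python
class _TritonAttnBackend:
    # placeholder bodies: a real backend would call the triton kernels
    def prefill(self, attn_metadata, **kwargs):
        return "context attention path"

    def decode(self, attn_metadata, **kwargs):
        return "flash-decoding path"


class _CudaAttnBackend:
    # placeholder bodies: a real backend would call the cuda inference ops / flash-attn 2
    def prefill(self, attn_metadata, **kwargs):
        return "varlen flash-attn path"

    def decode(self, attn_metadata, **kwargs):
        return "cuda flash-decoding path"


def pick_backend(use_cuda_kernel: bool, dtype_supported: bool):
    # selection happens once at layer construction, so forward() stays branch-free
    return _CudaAttnBackend() if (use_cuda_kernel and dtype_supported) else _TritonAttnBackend()
```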
@@ -462,6 +452,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
         attn_vproj_w = module.v_proj.weight
         assert is_distributed_tensor(attn_qproj_w), "attn_qproj_w must be dist tensor"
         attn_oproj = module.o_proj
+        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)

         attn_layer = NopadLlamaAttention(
             config=config,
@@ -471,6 +462,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
             attn_vproj_w=attn_vproj_w,
             attn_oproj=attn_oproj,
             process_group=process_group,
+            model_shard_infer_config=model_shard_infer_config,
             num_heads=module.num_heads,
             hidden_size=module.hidden_size,
             num_key_value_heads=module.num_key_value_heads,
@@ -533,109 +525,48 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):

         block_size = k_cache.size(-2)

-        if is_prompts:
+        attn_metadata = AttentionMetaData(
-            if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
+            query_states=query_states,
-                # flash attn 2 currently only supports FP16/BF16.
+            key_states=key_states,
-                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
+            value_states=value_states,
-                inference_ops.context_kv_cache_memcpy(
-                    key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
-                )

-                attn_output = flash_attn_varlen_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    cu_seqlens_q=cu_seqlens,
-                    cu_seqlens_k=cu_seqlens,
-                    max_seqlen_q=kv_seq_len,
-                    max_seqlen_k=kv_seq_len,
-                    dropout_p=0.0,
-                    softmax_scale=sm_scale,
-                    causal=True,
-                )
-                attn_output = attn_output.view(token_nums, -1)
-            else:
-                rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-                attn_output = context_attention_unpadded(
-                    q=query_states,
-                    k=key_states,
-                    v=value_states,
             k_cache=k_cache,
             v_cache=v_cache,
-                    context_lengths=sequence_lengths,
             block_tables=block_tables,
             block_size=block_size,
-                    output=output_tensor,
+            kv_seq_len=kv_seq_len,
-                    max_seq_len=kv_seq_len,
+            sequence_lengths=sequence_lengths,
             sm_scale=sm_scale,
-                    use_new_kcache_layout=use_cuda_kernel,
+            alibi_slopes=None,
+            cu_seqlens=cu_seqlens,
+            output_tensor=output_tensor,
+            use_spec_dec=is_verifier,
+            use_alibi_attn=False,
         )
-        else:
+        if is_prompts:  # prefilling stage
+            self.pre_attention_backend.prefill(
+                attn_metadata,
+                cos=cos_sin[0],
+                sin=cos_sin[1],
+                high_precision=high_precision,
+            )
+            attn_output = self.attention_backend.prefill(
+                attn_metadata,
+                token_nums=token_nums,
+            )
+        else:  # decoding stage
             q_len = tokens_to_verify + 1 if is_verifier else 1

-            if use_cuda_kernel:
+            self.pre_attention_backend.decode(
-                inference_ops.rotary_embedding_and_cache_copy(
+                attn_metadata,
-                    query_states,
+                cos=cos_sin[0],
-                    key_states,
+                sin=cos_sin[1],
-                    value_states,
+                q_len=q_len,
-                    cos_sin[0],
-                    cos_sin[1],
-                    k_cache,
-                    v_cache,
-                    sequence_lengths,
-                    block_tables,
-                    high_precision,
             )
-                inference_ops.flash_decoding_attention(
+            attn_output = self.attention_backend.decode(
-                    output_tensor,
+                attn_metadata,
-                    query_states,
+                fd_inter_tensor=fd_inter_tensor,
-                    k_cache,
+                num_key_value_groups=self.num_key_value_groups,
-                    v_cache,
-                    sequence_lengths,
-                    block_tables,
-                    block_size,
-                    kv_seq_len,
-                    fd_inter_tensor.mid_output,
-                    fd_inter_tensor.exp_sums,
-                    fd_inter_tensor.max_logits,
-                    None,
-                    sm_scale,
-                )
-                attn_output = output_tensor
-            else:
-                if is_verifier:
-                    rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-                    copy_k_to_blocked_cache(
-                        key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
-                    )
-                    copy_k_to_blocked_cache(
-                        value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
-                    )
-                else:
-                    decoding_fused_rotary_embedding(
-                        query_states,
-                        key_states,
-                        value_states,
-                        cos_sin[0],
-                        cos_sin[1],
-                        k_cache,
-                        v_cache,
-                        block_tables,
-                        sequence_lengths,
-                    )
-                attn_output = flash_decoding_attention(
-                    q=query_states,
-                    k_cache=k_cache,
-                    v_cache=v_cache,
-                    kv_seq_len=sequence_lengths,
-                    block_tables=block_tables,
-                    block_size=block_size,
-                    max_seq_len_in_batch=kv_seq_len,
-                    output=output_tensor,
-                    mid_output=fd_inter_tensor.mid_output,
-                    mid_output_lse=fd_inter_tensor.mid_output_lse,
-                    sm_scale=sm_scale,
-                    kv_group_num=self.num_key_value_groups,
                 q_len=q_len,
             )


@@ -1,8 +1,5 @@
 from colossalai.inference.config import RPC_PARAM
-from colossalai.inference.modeling.layers.baichuan_tp_linear import (
+from colossalai.inference.modeling.layers.baichuan_tp_linear import BaichuanLMHeadLinear1D_Col
-    BaichuanLMHeadLinear1D_Col,
-    BaichuanWpackLinear1D_Col,
-)
 from colossalai.inference.modeling.models.nopadding_baichuan import (
     NopadBaichuanAttention,
     NopadBaichuanMLP,
@@ -14,7 +11,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
     llama_model_forward,
 )
 from colossalai.inference.utils import init_to_get_rotary
-from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import FusedLinear1D_Col, Linear1D_Col, Linear1D_Row
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

@@ -60,8 +57,7 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
                     target_module=NopadBaichuanMLP,
                 ),
                 SubModuleReplacementDescription(
-                    suffix="self_attn.W_pack",
+                    suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
-                    target_module=BaichuanWpackLinear1D_Col,
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.o_proj",
@@ -70,6 +66,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
                 SubModuleReplacementDescription(
                     suffix="self_attn",
                     target_module=NopadBaichuanAttention,
+                    kwargs={
+                        "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
+                    },
                 ),
             ],
         )

@@ -72,6 +72,9 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
                 SubModuleReplacementDescription(
                     suffix="self_attn",
                     target_module=NopadLlamaAttention,
+                    kwargs={
+                        "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
+                    },
                 ),
             ],
         )

@@ -46,6 +46,7 @@ class GlideInput:
     large_k_cache: torch.Tensor = None
     large_v_cache: torch.Tensor = None
     sequence_lengths: torch.Tensor = None
+    n_spec_tokens: int = 5

     @property
     def glimpse_ready(self):

@@ -1,6 +1,7 @@
"""
|
"""
|
||||||
Utils for model inference
|
Utils for model inference
|
||||||
"""
|
"""
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -9,8 +10,11 @@ from typing import Optional, Tuple
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.testing import free_port
|
from colossalai.testing import free_port
|
||||||
|
|
||||||
|
logger = get_dist_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def init_to_get_rotary(self, base=10000, use_elem=False):
|
def init_to_get_rotary(self, base=10000, use_elem=False):
|
||||||
"""
|
"""
|
||||||
|
@ -113,3 +117,44 @@ def find_available_ports(num: int):
|
||||||
print(f"An OS error occurred: {e}")
|
print(f"An OS error occurred: {e}")
|
||||||
raise RuntimeError("Error finding available ports")
|
raise RuntimeError("Error finding available ports")
|
||||||
return free_ports
|
return free_ports
|
||||||
|
|
||||||
|
|
||||||
|
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_heads (int): The number of attention heads.
|
||||||
|
device (torch.device): The device to use.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The Alibi slopes.
|
||||||
|
"""
|
||||||
|
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
|
||||||
|
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
|
||||||
|
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
|
||||||
|
slopes = torch.pow(base, powers)
|
||||||
|
if closest_power_of_2 != num_heads:
|
||||||
|
extra_base = torch.tensor(
|
||||||
|
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
|
||||||
|
)
|
||||||
|
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
|
||||||
|
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
|
||||||
|
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||||
|
return slopes
|
||||||
|
|
||||||
|
|
||||||
|
def can_use_flash_attn2(dtype: torch.dtype) -> bool:
|
||||||
|
"""
|
||||||
|
Check flash attention2 availability.
|
||||||
|
"""
|
||||||
|
if dtype not in (torch.float16, torch.bfloat16):
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from flash_attn import flash_attn_varlen_func # noqa
|
||||||
|
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
|
||||||
|
return False
|
||||||
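A small worked example of the slope formula added above, for a power-of-two head count (plain PyTorch, no ColossalAI imports). For num_heads = 8 the base works out to 2 ** (-8 / 8) = 0.5, so head i gets slope 0.5 ** i; non-power-of-two head counts pick up the extra interpolated slopes from the second branch. These slopes are the per-head linear penalties ALiBi adds to attention scores, so larger slopes decay attention to distant tokens faster.

```python
import math

import torch

num_heads = 8
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))  # 8
base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))  # 0.5
slopes = torch.pow(torch.tensor(base), torch.arange(1, 1 + closest_power_of_2))
print(slopes)
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
```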

@@ -45,7 +45,10 @@ def launch(
     backend = cur_accelerator.communication_backend

     # init default process group
+    if ":" in host:  # IPv6
         init_method = f"tcp://[{host}]:{port}"
+    else:  # IPv4
+        init_method = f"tcp://{host}:{port}"
     dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)

     # set cuda device
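The branch added above only changes how the rendezvous URL is formed; a quick illustration (host values are made up):

```python
def build_init_method(host: str, port: int) -> str:
    return f"tcp://[{host}]:{port}" if ":" in host else f"tcp://{host}:{port}"

print(build_init_method("fe80::1", 29500))   # tcp://[fe80::1]:29500  (IPv6 needs brackets)
print(build_init_method("10.0.0.1", 29500))  # tcp://10.0.0.1:29500   (IPv4 stays bare)
```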

@@ -50,7 +50,7 @@ def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.T
     seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
     max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
     return max_seqlen_in_batch, cu_seqlens, indices

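The fix above only corrects the doubled `torch.torch.int32` spelling to `torch.int32`; the padded cumulative sum itself is the standard way to build varlen offsets. A quick check with example lengths:

```python
import torch
import torch.nn.functional as F

seqlens_in_batch = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
```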

@@ -140,32 +140,29 @@ class RMSNorm(BaseLayerNorm):

 class LayerNorm(BaseLayerNorm):
     r"""
-    This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
+    This is a wrapper around native LayerNorm. It is meant to be used only with the from_native_module interface.
     """

     def __init__(self) -> None:
         raise NotImplementedError(
             "LayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module."
+            "It is meant to be used only with the from_native_module interface to convert a native LayerNorm module to colossalai layer norm module."
         )

     @staticmethod
-    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
         r"""
-        Convert a native pytorch layer norm module to colossalai layer norm module,
+        Convert a native LayerNorm module to colossalai layer norm module,
         and optionally marking parameters for gradient aggregation.

         Args:
-            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            module (nn.Module): The native LayerNorm module to be converted.
             sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

         Returns:
-            nn.Module: The LayerNorm module.
+            nn.Module: The colossalai LayerNorm module.

-        Raises:
-            AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """
-        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."

         LazyInitContext.materialize(module)

@@ -174,6 +171,7 @@ class LayerNorm(BaseLayerNorm):
         # aggregation of these gradients is necessary during backpropagation.
         # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
         SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
+        if module.bias is not None:
             SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)

         return module
@@ -187,31 +185,29 @@ class FusedLayerNorm(BaseLayerNorm):
     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
+            "It is meant to be used only with the from_native_module interface convert a native LayerNorm module to FusedLayerNorm module provided by apex."
         )

     @staticmethod
     def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
         r"""
-        Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex,
+        Convert a native LayerNorm module to FusedLayerNorm module provided by apex,
         and optionally marking parameters for gradient aggregation.

         Args:
-            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            module (nn.Module): The native LayerNorm module to be converted.
             sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

         Returns:
             nn.Module: Union[FastLayerNorm, FusedLayerNorm].

-        Raises:
-            AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """

         LazyInitContext.materialize(module)
         # get the attributes of the module
-        normalized_shape = module.normalized_shape
+        normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
-        eps = module.eps
+        eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
-        elementwise_affine = module.elementwise_affine
+        elementwise_affine = getattr(module, "elementwise_affine", True)
         dtype = module.weight.dtype
         device = module.weight.device

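The getattr/hasattr fallbacks introduced above let `from_native_module` accept norm layers that do not carry all of `nn.LayerNorm`'s attributes. A minimal sketch with a hypothetical RMSNorm-style module (only the attribute lookups are shown, not the full conversion):

```python
import torch
from torch import nn


class ToyRMSNorm(nn.Module):
    """Hypothetical stand-in: exposes weight and variance_epsilon, like HF RMSNorm layers."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps


m = ToyRMSNorm(1024)
normalized_shape = getattr(m, "normalized_shape", m.weight.shape[0])   # falls back to 1024
eps = m.variance_epsilon if hasattr(m, "variance_epsilon") else m.eps  # 1e-6
elementwise_affine = getattr(m, "elementwise_affine", True)            # defaults to True
```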
@@ -229,7 +225,7 @@ class FusedLayerNorm(BaseLayerNorm):
             ApexFusedLayerNorm = FusedLayerNormWithHook
         except NameError:
             warnings.warn(
-                "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
+                "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using native layernorm instead."
             )
             return module

@@ -237,6 +233,7 @@ class FusedLayerNorm(BaseLayerNorm):
             ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
         )
         layernorm.weight = module.weight
+        if module.bias is not None:
             layernorm.bias = module.bias

         if sp_partial_derived:

@@ -475,7 +475,10 @@ class BloomPipelineForwards:
                 sequence_lengths = -1
             else:
                 if input_ids is not None:
-                    sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+                    # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                    sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                    sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                    sequence_lengths = sequence_lengths.to(logits.device)
                 else:
                     sequence_lengths = -1
                     logger.warning(
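A small worked example of the modulo trick introduced above: `argmax` returns the first pad position, so subtracting one gives the index of the last real token, and when no pad token is present `argmax` returns 0 and the modulo wraps -1 around to the final position, matching the upstream comment about ONNX compatibility.

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([
    [5, 6, 7, 0, 0],  # padded: last real token at index 2
    [5, 6, 7, 8, 9],  # no padding: last real token at index 4
])
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 4])
```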
|
|
|
@ -0,0 +1,692 @@
|
||||||
|
import math
|
||||||
|
import warnings
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
from transformers.cache_utils import Cache, DynamicCache
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
|
from transformers.models.cohere.modeling_cohere import (
|
||||||
|
CohereForCausalLM,
|
||||||
|
CohereModel,
|
||||||
|
StaticCache,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
repeat_kv,
|
||||||
|
)
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
|
from colossalai.shardformer.layer._operation import (
|
||||||
|
all_to_all_comm,
|
||||||
|
gather_forward_split_backward,
|
||||||
|
split_forward_gather_backward,
|
||||||
|
)
|
||||||
|
from colossalai.shardformer.shard import ShardConfig
|
||||||
|
|
||||||
|
from ..layer import ColoAttention, cross_entropy_1d
|
||||||
|
|
||||||
|
|
||||||
|
class CommandPipelineForwards:
|
||||||
|
"""
|
||||||
|
This class serves as a micro library for forward function substitution of Command models
|
||||||
|
under pipeline setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def command_model_forward(
|
||||||
|
self: CohereModel,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
stage_index: Optional[List[int]] = None,
|
||||||
|
shard_config: ShardConfig = None,
|
||||||
|
):
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with pipeline parallelism. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
if stage_manager.is_first_stage():
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
batch_size, seq_length = input_ids.shape[:2]
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape[:2]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
else:
|
||||||
|
input_shape = hidden_states.shape[:-1]
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
device = hidden_states.device
|
||||||
|
|
||||||
|
past_seen_tokens = 0
|
||||||
|
if use_cache: # kept for BC (cache positions)
|
||||||
|
if not isinstance(past_key_values, StaticCache):
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length()
|
||||||
|
if cache_position is None:
|
||||||
|
if isinstance(past_key_values, StaticCache):
|
||||||
|
raise ValueError("cache_position is a required argument when using StaticCache.")
|
||||||
|
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)
|
||||||
|
|
||||||
|
seq_length_with_past = seq_length + past_seen_tokens
|
||||||
|
|
||||||
|
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||||
|
if output_attentions:
|
||||||
|
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||||
|
output_attentions = False
|
||||||
|
if output_hidden_states:
|
||||||
|
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||||
|
output_hidden_states = False
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
|
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||||
|
# for the other stages, hidden_states is the output of the previous stage
|
||||||
|
if shard_config.enable_flash_attention:
|
||||||
|
# in this case, attention_mask is a dict rather than a tensor
|
||||||
|
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
|
||||||
|
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||||
|
mask_shape,
|
||||||
|
hidden_states.dtype,
|
||||||
|
hidden_states.device,
|
||||||
|
q_padding_mask=attention_mask,
|
||||||
|
is_causal=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training and use_cache:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
next_decoder_cache = None
|
||||||
|
|
||||||
|
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||||
|
num_ckpt_layers = 0
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
num_ckpt_layers = end_idx - start_idx
|
||||||
|
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
|
||||||
|
if shard_config.gradient_checkpoint_config is not None:
|
||||||
|
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
|
||||||
|
stage=stage_manager.stage,
|
||||||
|
num_stages=stage_manager.num_stages,
|
||||||
|
num_layers=end_idx - start_idx,
|
||||||
|
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
|
||||||
|
num_model_chunks=stage_manager.num_model_chunks,
|
||||||
|
)
|
||||||
|
assert num_ckpt_layers <= end_idx - start_idx
|
||||||
|
|
||||||
|
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
if idx - start_idx < num_ckpt_layers:
|
||||||
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
|
decoder_layer.__call__,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
|
past_key_values,
|
||||||
|
output_attentions,
|
||||||
|
use_cache,
|
||||||
|
cache_position,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
if stage_manager.is_last_stage():
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
if stage_manager.is_last_stage():
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v
|
||||||
|
for v in [
|
||||||
|
hidden_states,
|
||||||
|
next_cache,
|
||||||
|
all_hidden_states,
|
||||||
|
all_self_attns,
|
||||||
|
]
|
||||||
|
if v is not None
|
||||||
|
)
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
|
||||||
|
# always return dict for imediate stage
|
||||||
|
return {"hidden_states": hidden_states}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def command_for_causal_lm_forward(
|
||||||
|
self: CohereForCausalLM,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
stage_index: Optional[List[int]] = None,
|
||||||
|
shard_config: ShardConfig = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, CohereForCausalLM
|
||||||
|
|
||||||
|
>>> model = CohereForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||||
|
|
||||||
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||||
|
if output_attentions:
|
||||||
|
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||||
|
output_attentions = False
|
||||||
|
if output_hidden_states:
|
||||||
|
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||||
|
output_hidden_states = False
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = CommandPipelineForwards.command_model_forward(
|
||||||
|
self.model,
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
stage_manager=stage_manager,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
stage_index=stage_index,
|
||||||
|
shard_config=shard_config,
|
||||||
|
)
|
||||||
|
past_key_values = None
|
||||||
|
|
||||||
|
if stage_manager.is_last_stage():
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
logits = logits * self.logit_scale
|
||||||
|
logits = logits.float()
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
|
||||||
|
new_vocab_size = logits.shape[-1]
|
||||||
|
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||||
|
loss = cross_entropy_1d(
|
||||||
|
shift_logits,
|
||||||
|
shift_labels,
|
||||||
|
process_group=shard_config.tensor_parallel_process_group,
|
||||||
|
vocab_size=self.lm_head.out_features,
|
||||||
|
dtype=self.model.dtype,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = outputs.get("hidden_states")
|
||||||
|
return {"hidden_states": hidden_states}
def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if sp_mode is not None:
            assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
            assert (sp_size is not None) and (
                sp_group is not None
            ), "Must specify sp_size and sp_group for sequence parallel"
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        bsz, q_len, _ = hidden_states.size()
        # sp: modify sp_len when sequence parallel mode is ring
        if sp_mode in ["split_gather", "ring"]:
            q_len *= sp_size

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # sp: all-to-all comminucation when introducing sequence parallel
        if sp_mode == "all_to_all":
            query_states = all_to_all_comm(query_states, sp_group)
            key_states = all_to_all_comm(key_states, sp_group)
            value_states = all_to_all_comm(value_states, sp_group)
            bsz, q_len, _ = query_states.size()

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if shard_config.enable_flash_attention:
            assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
            attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
        else:
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                    f" {attn_weights.size()}"
                )

            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                    )
                attn_weights = attn_weights + attention_mask

            # upcast attention to fp32
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights, value_states)

            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )

        attn_output = attn_output.transpose(1, 2).contiguous()
        # sp: all-to-all comminucation when introducing sequence parallel
        if sp_mode == "all_to_all":
            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
        else:
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value

    return forward
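For readers unfamiliar with grouped-query attention, the `repeat_kv` call above expands the key/value heads to match the query heads. A minimal reference version, assuming the usual `(bsz, num_kv_heads, seq, head_dim)` layout (this mirrors the Hugging Face helper; it is not the ColossalAI code itself):

```python
import torch

def repeat_kv_ref(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (bsz, num_kv_heads, seq, head_dim) -> (bsz, num_kv_heads * n_rep, seq, head_dim)
    bsz, num_kv_heads, seq, head_dim = x.shape
    if n_rep == 1:
        return x
    x = x[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, seq, head_dim)
    return x.reshape(bsz, num_kv_heads * n_rep, seq, head_dim)

kv = torch.randn(1, 2, 4, 8)          # 2 kv heads
print(repeat_kv_ref(kv, 4).shape)     # torch.Size([1, 8, 4, 8]), matching 8 query heads
```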
def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
    logger = logging.get_logger(__name__)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        past_seen_tokens = 0
        seq_len = inputs_embeds.shape[1]
        if use_cache:  # kept for BC (cache positions)
            if not isinstance(past_key_values, StaticCache):
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                past_seen_tokens = past_key_values.get_seq_length()
        if cache_position is None:
            if isinstance(past_key_values, StaticCache):
                raise ValueError("cache_position is a required argument when using StaticCache.")
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # in this case, attention_mask is a dict rather than a tensor
        if shard_config.enable_flash_attention:
            mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
            attention_mask = ColoAttention.prepare_attn_kwargs(
                mask_shape,
                inputs_embeds.dtype,
                inputs_embeds.device,
                q_padding_mask=attention_mask,
                is_causal=True,
            )
        else:
            attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

        if sp_mode in ["ring", "split_gather"]:
            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
        elif sp_mode == "all_to_all":
            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )

            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if sp_mode == "ring" or sp_mode == "split_gather":
            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
        elif sp_mode == "all_to_all":
            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
            )
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    return forward
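The `cache_position` / `position_ids` bookkeeping above can be hard to picture; this toy snippet shows the intended values when a KV cache already holds a few tokens (numbers invented for illustration):

```python
import torch

past_seen_tokens, seq_len = 6, 3  # cache holds 6 tokens, 3 new tokens arrive
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len)
position_ids = cache_position.unsqueeze(0)  # add the batch dimension

print(cache_position.tolist())  # [6, 7, 8]
print(position_ids.shape)       # torch.Size([1, 3])
```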
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
    from transformers import CohereForCausalLM

    def forward(
        self: CohereForCausalLM,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, CohereForCausalLM

        >>> model = CohereForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]

        logits = self.lm_head(hidden_states)
        logits = logits * self.logit_scale
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            new_vocab_size = logits.shape[-1]
            shift_logits = shift_logits.view(-1, new_vocab_size)
            loss = cross_entropy_1d(
                shift_logits,
                shift_labels,
                process_group=shard_config.tensor_parallel_process_group,
                vocab_size=self.lm_head.out_features,
                dtype=self.model.dtype,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    return forward
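`cross_entropy_1d` computes the loss while the vocabulary dimension stays sharded across tensor-parallel ranks. The snippet below is not that kernel; it only demonstrates, on a single device, the identity it relies on: the full-vocab log-sum-exp can be assembled from per-shard log-sum-exps.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, tp = 12, 3
logits = torch.randn(5, vocab)              # 5 tokens, full vocabulary (reference)
labels = torch.randint(0, vocab, (5,))

shards = logits.chunk(tp, dim=-1)           # what each TP rank would hold
lse = torch.logsumexp(torch.stack([s.logsumexp(-1) for s in shards]), dim=0)
target_logit = logits.gather(-1, labels[:, None]).squeeze(-1)
loss_sharded = (lse - target_logit).mean()

assert torch.allclose(loss_sharded, F.cross_entropy(logits, labels), atol=1e-5)
```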
@@ -291,18 +291,17 @@ class FalconPipelineForwards:
                 if attention_mask_2d is None:
                     attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
                 else:
+                    min_dtype = torch.finfo(alibi.dtype).min
                     attention_mask = torch.masked_fill(
                         alibi / math.sqrt(self.config.hidden_size // self.num_heads),
                         attention_mask < -1,
-                        torch.finfo(alibi.dtype).min,
+                        min_dtype,
                     )

                     # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
                     # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
-                    if seq_length > 1:
-                        attention_mask = AttentionMaskConverter._unmask_unattended(
-                            attention_mask, attention_mask_2d, unmasked_value=0.0
-                        )
+                    if seq_length > 1 and attention_mask.device.type == "cuda":
+                        attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
         else:
             # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
             attention_mask = _prepare_4d_causal_attention_mask(
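The hunk above swaps a literal `torch.finfo(alibi.dtype).min` for a reusable `min_dtype`; the underlying idea is the usual additive-mask trick of pushing disallowed positions to the dtype's most negative value so softmax drives them to roughly zero. A toy illustration (shapes invented):

```python
import torch

scores = torch.randn(1, 1, 4, 4)
mask = torch.tensor([[[[False, False, True, True]]]])   # True = masked out
min_dtype = torch.finfo(scores.dtype).min
probs = scores.masked_fill(mask, min_dtype).softmax(-1)
print(probs[..., 2:].sum().item())  # ~0.0 for the masked columns
```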
@@ -543,7 +542,10 @@ class FalconPipelineForwards:
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1).to(logits.device)
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning(
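The replacement in this hunk (and the matching GPT2/GPTJ hunks below) avoids reverse indexing so the graph stays ONNX-exportable. A standalone check of the modulo trick with made-up token ids:

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 7, 9, 0, 0],    # padded sequence
                          [3, 4, 6, 2, 8]])   # no padding at all
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths.tolist())  # [2, 4] -> index of the last real token per row
```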
@@ -738,7 +738,10 @@ class GPT2PipelineForwards:
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning_once(

@@ -32,6 +32,7 @@ def _get_attention_mask(
     hidden_states: torch.Tensor,
     past_key_values: Optional[Tuple[Tuple[torch.Tensor]]],
     attention_mask: Optional[torch.FloatTensor],
+    use_flash_attention_2: bool = False,
 ) -> Optional[Union[torch.Tensor, dict]]:
     batch_size, seq_len = hidden_states.shape[:2]
     past_key_values_length = 0

@@ -47,7 +48,7 @@ def _get_attention_mask(
             attention_mask,
             is_causal=True,
         )
-    elif attention_mask is not None:
+    elif use_flash_attention_2 and attention_mask is not None:
         if batch_size <= 0:
             raise ValueError("batch_size has to be defined and > 0")
         attention_mask = attention_mask.view(batch_size, -1)

@@ -162,7 +163,9 @@ class GPTJPipelineForwards:

         output_shape = input_shape + (hidden_states.size(-1),)

-        attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
+        attention_mask = _get_attention_mask(
+            self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
+        )

         if self.gradient_checkpointing and self.training:
             if use_cache:

@@ -419,7 +422,10 @@ class GPTJPipelineForwards:
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning_once(

@@ -712,7 +718,9 @@ def gptj_model_forward_for_flash_attention(shard_config: ShardConfig):

         hidden_states = self.drop(hidden_states)

-        attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
+        attention_mask = _get_attention_mask(
+            self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
+        )

         output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

@@ -886,7 +894,9 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
         hidden_states = self.drop(hidden_states)

         output_shape = input_shape + (hidden_states.size(-1),)
-        attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
+        attention_mask = _get_attention_mask(
+            self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
+        )

         if self.gradient_checkpointing and self.training:
             if use_cache:
@@ -7,11 +7,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.cache_utils import Cache
-from transformers.modeling_attn_mask_utils import (
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
+from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,

@@ -21,6 +17,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
     LlamaForSequenceClassification,
     LlamaModel,
+    StaticCache,
     apply_rotary_pos_emb,
     repeat_kv,
 )

@@ -55,6 +52,7 @@ class LlamaPipelineForwards:
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,

@@ -67,6 +65,11 @@ class LlamaPipelineForwards:
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
+        if use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with pipeline parallelism. Setting `use_cache=False`..."
+            )
+            use_cache = False

         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

@@ -83,14 +86,24 @@ class LlamaPipelineForwards:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             if inputs_embeds is None:
                 inputs_embeds = self.embed_tokens(input_ids)

            hidden_states = inputs_embeds
         else:
             input_shape = hidden_states.shape[:-1]
             batch_size, seq_length = input_shape
             device = hidden_states.device

-        seq_length_with_past = seq_length
-        past_key_values_length = 0
+        past_seen_tokens = 0
+        if use_cache:  # kept for BC (cache positions)
+            if not isinstance(past_key_values, StaticCache):
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                past_seen_tokens = past_key_values.get_seq_length()
+        if cache_position is None:
+            if isinstance(past_key_values, StaticCache):
+                raise ValueError("cache_position is a required argument when using StaticCache.")
+            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)
+
+        seq_length_with_past = seq_length + past_seen_tokens

         # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
         if output_attentions:

@@ -103,18 +116,8 @@ class LlamaPipelineForwards:
             logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
             use_cache = False

-        if past_key_values is not None:
-            past_key_values_length = past_key_values[0][0].shape[2]
-            seq_length_with_past = seq_length_with_past + past_key_values_length
-
         if position_ids is None:
-            position_ids = torch.arange(
-                past_key_values_length,
-                seq_length + past_key_values_length,
-                dtype=torch.long,
-                device=device,
-            )
-            position_ids = position_ids.unsqueeze(0)
+            position_ids = cache_position.unsqueeze(0)

         # embed positions, for the first stage, hidden_states is the input embeddings,
         # for the other stages, hidden_states is the output of the previous stage

@@ -129,28 +132,9 @@ class LlamaPipelineForwards:
                 is_causal=True,
             )
         else:
-            if self._use_flash_attention_2:
-                # 2d mask is passed through the layers
-                attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-            elif self._use_sdpa and not output_attentions:
-                # output_attentions=True can not be supported when using SDPA, and we fall back on
-                # the manual implementation that requires a 4D causal mask in all cases.
-                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                    attention_mask,
-                    (batch_size, seq_length),
-                    inputs_embeds,
-                    past_key_values_length,
-                )
-            else:
-                # 4d mask is passed through the layers
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask,
-                    (batch_size, seq_length),
-                    hidden_states,
-                    past_key_values_length,
-                )
+            attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)

-        if self.gradient_checkpointing and self.training:
+        if self.gradient_checkpointing and self.training and use_cache:
             if use_cache:
                 logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
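The removed branches built the 4D causal mask by hand; `_update_causal_mask` now does that inside the model. For intuition, here is a minimal additive causal mask of the kind those helpers produce (a sketch, not the transformers implementation): 0 where attention is allowed, dtype-min where it is not.

```python
import torch

bsz, seq, dtype = 1, 4, torch.float32
min_dtype = torch.finfo(dtype).min
mask = torch.full((seq, seq), min_dtype, dtype=dtype).triu(diagonal=1)
mask = mask[None, None].expand(bsz, 1, seq, seq)  # (bsz, 1, q_len, kv_len)
print(mask[0, 0])
```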
@@ -190,6 +174,7 @@ class LlamaPipelineForwards:
                     past_key_values,
                     output_attentions,
                     use_cache,
+                    cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(

@@ -199,6 +184,7 @@ class LlamaPipelineForwards:
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
+                    cache_position=cache_position,
                 )

             hidden_states = layer_outputs[0]

@@ -249,6 +235,7 @@ class LlamaPipelineForwards:
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,

@@ -306,6 +293,7 @@ class LlamaPipelineForwards:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
             stage_manager=stage_manager,
             hidden_states=hidden_states,
             stage_index=stage_index,

@@ -368,6 +356,7 @@ class LlamaPipelineForwards:
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,

@@ -401,6 +390,7 @@ class LlamaPipelineForwards:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
             stage_manager=stage_manager,
             hidden_states=hidden_states,
             stage_index=stage_index,

@@ -470,33 +460,50 @@ class LlamaPipelineForwards:
         return {"hidden_states": hidden_states}


-def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
-    from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
-
-    try:
-        from transformers.models.llama.modeling_llama import repeat_kv
-    except:
-        warnings.warn("using llamav1, llamav1 hasn't repeat_kv function")
-
+def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
     def forward(
-        self: LlamaAttention,
+        self,
         hidden_states: torch.Tensor,
-        attention_mask: Optional[dict] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+        if sp_mode is not None:
+            assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
+            assert (sp_size is not None) and (
+                sp_group is not None
+            ), "Must specify sp_size and sp_group for sequence parallel"
         if "padding_mask" in kwargs:
             warnings.warn(
                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
             )
-        bsz, q_len, _ = hidden_states.size()
+
+        bsz, q_len, _ = hidden_states.size()
+        # sp: modify sp_len when sequence parallel mode is ring
         if sp_mode in ["split_gather", "ring"]:
             q_len *= sp_size

+        if self.config.pretraining_tp > 1:
+            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+            query_slices = self.q_proj.weight.split(
+                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+            )
+            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+            query_states = torch.cat(query_states, dim=-1)
+
+            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+            key_states = torch.cat(key_states, dim=-1)
+
+            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+            value_states = torch.cat(value_states, dim=-1)
+        else:
             query_states = self.q_proj(hidden_states)
             key_states = self.k_proj(hidden_states)
             value_states = self.v_proj(hidden_states)
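Why the added `pretraining_tp` branch is numerically equivalent to a single projection: slicing a Linear weight along its output dimension and concatenating the partial results reproduces the full matmul. A toy check (sizes invented, two slices standing in for `pretraining_tp=2`):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, out = 8, 12
x = torch.randn(3, 5, hidden)
w = torch.randn(out, hidden)

full = F.linear(x, w)
slices = w.split(out // 2, dim=0)                 # split the output dimension
pieces = [F.linear(x, s) for s in slices]         # partial projections
assert torch.allclose(torch.cat(pieces, dim=-1), full, atol=1e-5)
```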
@@ -520,338 +527,24 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )

             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
-        attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-        # sp: all-to-all comminucation when introducing sequence parallel
-        if sp_mode == "all_to_all":
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
-        attn_output = self.o_proj(attn_output)
-
-        return attn_output, None, past_key_value
-
-    return forward
-
-
-def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
-    logger = logging.get_logger(__name__)
-    assert shard_config.enable_flash_attention, "Flash Attention is not enabled."
-
-    def forward(
-        self: LlamaModel,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # retrieve input_ids and inputs_embeds
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif input_ids is not None:
-            batch_size, seq_length = input_ids.shape
-        elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape
-        else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
-        seq_length_with_past = seq_length
-        past_key_values_length = 0
-
-        if past_key_values is not None:
-            past_key_values_length = past_key_values[0][0].shape[2]
-            seq_length_with_past = seq_length_with_past + past_key_values_length
-
-        if position_ids is None:
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(
-                past_key_values_length,
-                seq_length + past_key_values_length,
-                dtype=torch.long,
-                device=device,
-            )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-        # embed positions
-        hidden_states = inputs_embeds
-
-        # in this case, attention_mask is a dict rather than a tensor
-        mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
-        attention_mask = ColoAttention.prepare_attn_kwargs(
-            mask_shape,
-            hidden_states.dtype,
-            hidden_states.device,
-            q_padding_mask=attention_mask,
-            is_causal=True,
-        )
-
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        next_decoder_cache = () if use_cache else None
-
-        for idx, decoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            past_key_value = past_key_values[idx] if past_key_values is not None else None
-
-            if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, past_key_value, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
-                    hidden_states,
-                    attention_mask,
-                    position_ids,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                )
-
-            hidden_states = layer_outputs[0]
-
-            if use_cache:
-                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        next_cache = next_decoder_cache if use_cache else None
-        if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
-
-    return forward
-
-
-def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
-    from transformers import LlamaForCausalLM
-
-    def forward(
-        self: LlamaForCausalLM,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
-        r"""
-        Args:
-            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-        Returns:
-
-        Example:
-
-        ```python
-        >>> from transformers import AutoTokenizer, LlamaForCausalLM
-
-        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
-        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
-
-        >>> prompt = "Hey, are you conscious? Can you talk to me?"
-        >>> inputs = tokenizer(prompt, return_tensors="pt")
-
-        >>> # Generate
-        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
-        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
-        ```"""
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        hidden_states = outputs[0]
-        if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-            logits = torch.cat(logits, dim=-1)
-        else:
-            logits = self.lm_head(hidden_states)
-        logits = logits.float()
-
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            new_vocab_size = logits.shape[-1]
-            shift_logits = shift_logits.view(-1, new_vocab_size)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.model.dtype,
-            )
-
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-    return forward
-
-
-def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
-        # sp: modify sp_len when sequence parallel mode is ring
-        if sp_mode in ["split_gather", "ring"]:
-            q_len *= sp_size
-        if self.config.pretraining_tp > 1:
-            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
-            query_slices = self.q_proj.weight.split(
-                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-            )
-            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
-            query_states = torch.cat(query_states, dim=-1)
-
-            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
-            key_states = torch.cat(key_states, dim=-1)
-
-            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
-            value_states = torch.cat(value_states, dim=-1)
-
-        else:
-            query_states = self.q_proj(hidden_states)
-            key_states = self.k_proj(hidden_states)
-            value_states = self.v_proj(hidden_states)
-
-        # sp: all-to-all comminucation when introducing sequence parallel
-        if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group)
-            key_states = all_to_all_comm(key_states, sp_group)
-            value_states = all_to_all_comm(value_states, sp_group)
-            bsz, q_len, _ = query_states.size()
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
-        if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-        past_key_value = (key_states, value_states) if use_cache else None

         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

+        if shard_config.enable_flash_attention:
+            assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
+            attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
+        else:
             attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

             if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
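The hunk above changes how `cos`/`sin` are obtained (indexed by `position_ids` rather than recomputed from `seq_len`), but the rotation itself is unchanged. As background, this is the standard rotate-half form of rotary embeddings on toy shapes (a sketch of the math, not the model's `apply_rotary_pos_emb`):

```python
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

bsz, heads, seq, dim = 1, 2, 4, 8
q = torch.randn(bsz, heads, seq, dim)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(seq).float(), inv_freq)   # (seq, dim/2)
emb = torch.cat((freqs, freqs), dim=-1)                    # (seq, dim)
cos, sin = emb.cos()[None, None], emb.sin()[None, None]    # broadcast over batch/heads
q_rot = q * cos + rotate_half(q) * sin
print(q_rot.shape)  # torch.Size([1, 2, 4, 8])
```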
@ -899,7 +592,7 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
|
||||||
return forward
|
return forward
|
||||||
|
|
||||||
|
|
-def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
+def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
     logger = logging.get_logger(__name__)

     def forward(
@@ -913,66 +606,20 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache

         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # retrieve input_ids and inputs_embeds
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif input_ids is not None:
-            batch_size, seq_length = input_ids.shape
-        elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape
-        else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )

-        seq_length_with_past = seq_length
-        past_key_values_length = 0

-        if past_key_values is not None:
-            past_key_values_length = past_key_values[0][0].shape[2]
-            # modify past_key_values_length when using sequence parallel
-            past_key_values_length *= sp_size
-            seq_length_with_past = seq_length_with_past + past_key_values_length

-        if position_ids is None:
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(
-                past_key_values_length,
-                seq_length + past_key_values_length,
-                dtype=torch.long,
-                device=device,
-            )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()

-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)

-        if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
-        elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)

-        if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_length_with_past),
-                dtype=torch.bool,
-                device=inputs_embeds.device,
-            )

-        attention_mask = _prepare_4d_causal_attention_mask(
-            attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length
-        )

-        hidden_states = inputs_embeds

         if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
             if use_cache:
@@ -981,31 +628,60 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
                 )
                 use_cache = False

+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)

+        past_seen_tokens = 0
+        seq_len = inputs_embeds.shape[1]
+        if use_cache:  # kept for BC (cache positions)
+            if not isinstance(past_key_values, StaticCache):
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                past_seen_tokens = past_key_values.get_seq_length()
+        if cache_position is None:
+            if isinstance(past_key_values, StaticCache):
+                raise ValueError("cache_position is a required argument when using StaticCache.")
+            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)

+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)

+        # in this case, attention_mask is a dict rather than a tensor
+        if shard_config.enable_flash_attention:
+            mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
+            attention_mask = ColoAttention.prepare_attn_kwargs(
+                mask_shape,
+                inputs_embeds.dtype,
+                inputs_embeds.device,
+                q_padding_mask=attention_mask,
+                is_causal=True,
+            )
+        else:
+            attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

+        if sp_mode in ["ring", "split_gather"]:
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+        elif sp_mode == "all_to_all":
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+        hidden_states = inputs_embeds

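The new model forward above converts a legacy tuple cache into a `DynamicCache` and derives `cache_position` and `position_ids` from the number of tokens already cached. A hedged, single-process sketch of that bookkeeping follows, assuming a recent `transformers` release that ships `DynamicCache` (exact version support may differ).

```python
# Hedged sketch of the cache bookkeeping used above; not the ColossalAI code itself.
import torch
from transformers.cache_utils import DynamicCache

batch, seq_len, hidden = 1, 4, 8
inputs_embeds = torch.randn(batch, seq_len, hidden)

past_key_values = None  # legacy tuple cache or None on the first step
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()  # 0 when nothing is cached yet

# positions of the new tokens relative to everything already in the cache
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)
position_ids = cache_position.unsqueeze(0)
print(cache_position.tolist(), position_ids.shape)  # [0, 1, 2, 3] torch.Size([1, 4])
```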
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
-        next_decoder_cache = () if use_cache else None
+        next_decoder_cache = None

-        for idx, decoder_layer in enumerate(self.layers):
+        for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, past_key_value, output_attentions)

-                    return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
                 )

             else:
@@ -1013,15 +689,16 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
                     hidden_states,
                     attention_mask=attention_mask,
                     position_ids=position_ids,
-                    past_key_value=past_key_value,
+                    past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
+                    cache_position=cache_position,
                 )

             hidden_states = layer_outputs[0]

             if use_cache:
-                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

@@ -1037,7 +714,11 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        next_cache = next_decoder_cache if use_cache else None
+        next_cache = None
+        if use_cache:
+            next_cache = (
+                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
+            )
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

@@ -1049,3 +730,108 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
         )

     return forward

+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    from transformers import LlamaForCausalLM
+
+    def forward(
+        self: LlamaForCausalLM,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+
+        hidden_states = outputs[0]
+        if self.config.pretraining_tp > 1:
+            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
+            logits = torch.cat(logits, dim=-1)
+        else:
+            logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+                dtype=self.model.dtype,
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    return forward

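The forward added above shifts logits and labels by one position before handing them to `cross_entropy_1d` over the tensor-parallel group. A hedged, single-process sketch of that shift-and-flatten step follows, with `torch.nn.functional.cross_entropy` standing in for the distributed `cross_entropy_1d` kernel (which additionally reduces partial losses across the TP group).

```python
# Hedged sketch: the causal-LM label shift used by the forward above.
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab)
shift_labels = labels[..., 1:].contiguous().view(-1)

# F.cross_entropy stands in for the tensor-parallel cross_entropy_1d
loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
print(loss.item())
```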
@@ -4,7 +4,10 @@ from typing import List, Optional, Tuple, Union
 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -77,7 +80,7 @@ class MistralForwards:
         else:
             position_ids = position_ids.view(-1, seq_length).long()

-        if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
                 raise ValueError(
@@ -97,9 +100,18 @@ class MistralForwards:
                 is_causal=True,
             )
         else:
-            if self._use_flash_attention_2:
+            if self._attn_implementation == "flash_attention_2":
                 # 2d mask is passed through the layers
                 attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+            elif self._attn_implementation == "sdpa" and not output_attentions:
+                # output_attentions=True can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                )
             else:
                 # 4d mask is passed through the layers
                 attention_mask = _prepare_4d_causal_attention_mask(
@@ -462,7 +474,7 @@ def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
                 raise ValueError(
@@ -481,9 +493,18 @@ def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig):
                 is_causal=True,
             )
         else:
-            if self._use_flash_attention_2:
+            if self._attn_implementation == "flash_attention_2":
                 # 2d mask is passed through the layers
                 attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+            elif self._attn_implementation == "sdpa" and not output_attentions:
+                # output_attentions=True can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                )
             else:
                 # 4d mask is passed through the layers
                 attention_mask = _prepare_4d_causal_attention_mask(

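The Mistral hunks above replace the boolean `_use_flash_attention_2` check with the string-valued `_attn_implementation` and add an SDPA branch. The sketch below mirrors only the dispatch logic; the real mask helpers come from `transformers.modeling_attn_mask_utils` and are stubbed here with placeholder return values so the control flow can be read on its own.

```python
# Hedged sketch of the mask dispatch added above; helpers are stubbed out.
def prepare_mask(attn_implementation, attention_mask, output_attentions=False):
    if attn_implementation == "flash_attention_2":
        # 2d mask is passed through the layers (or dropped if it contains no padding)
        return attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    if attn_implementation == "sdpa" and not output_attentions:
        return "4d-sdpa-mask"  # would call _prepare_4d_causal_attention_mask_for_sdpa(...)
    return "4d-eager-mask"  # would call _prepare_4d_causal_attention_mask(...)


print(prepare_mask("flash_attention_2", [1, 1, 0]))  # [1, 1, 0]
print(prepare_mask("sdpa", [1, 1, 1]))               # 4d-sdpa-mask
print(prepare_mask("eager", [1, 1, 1]))              # 4d-eager-mask
```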
@@ -17,6 +17,7 @@ from transformers.modeling_outputs import (
     SequenceClassifierOutput,
 )
 from transformers.models.whisper.modeling_whisper import (
+    _HIDDEN_STATES_START_POSITION,
     WhisperDecoder,
     WhisperEncoder,
     WhisperForAudioClassification,
@@ -166,6 +167,7 @@ def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):
         cross_attn_head_mask=None,
         past_key_values=None,
         inputs_embeds=None,
+        position_ids=None,
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
@@ -199,9 +201,13 @@ def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):

         # embed positions
         if input_ids is not None:
-            positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
+            positions = self.embed_positions(
+                input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
+            )
         else:
-            positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
+            positions = self.embed_positions(
+                inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
+            )

         hidden_states = inputs_embeds + positions
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -599,6 +605,7 @@ class WhisperPipelineForwards:
         cross_attn_head_mask=None,
         past_key_values=None,
         inputs_embeds=None,
+        position_ids=None,
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
@@ -716,9 +723,13 @@ class WhisperPipelineForwards:

         # embed positions
         if input_ids is not None:
-            positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
+            positions = self.embed_positions(
+                input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
+            )
         else:
-            positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
+            positions = self.embed_positions(
+                inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
+            )

         hidden_states = inputs_embeds + positions
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -841,6 +852,7 @@ class WhisperPipelineForwards:
         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -944,6 +956,7 @@ class WhisperPipelineForwards:
             cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
             inputs_embeds=decoder_inputs_embeds,
+            position_ids=decoder_position_ids,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -986,6 +999,7 @@ class WhisperPipelineForwards:
         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
@@ -1048,6 +1062,7 @@ class WhisperPipelineForwards:
             cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
             decoder_inputs_embeds=decoder_inputs_embeds,
+            decoder_position_ids=decoder_position_ids,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -1118,6 +1133,12 @@ class WhisperPipelineForwards:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )

+        if self.config.use_weighted_layer_sum:
+            output_hidden_states = True
+        elif output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states

         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # audio_classification only holds encoder
@@ -1138,7 +1159,8 @@ class WhisperPipelineForwards:
             return encoder_outputs

         if self.config.use_weighted_layer_sum:
-            hidden_states = torch.stack(encoder_outputs, dim=1)
+            hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
+            hidden_states = torch.stack(hidden_states, dim=1)
             norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
             hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
         else:

@@ -192,6 +192,13 @@ _POLICY_LIST = {
     "transformers.models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification": PolicyLocation(
         file_name="qwen2", class_name="Qwen2ForSequenceClassificationPolicy"
     ),
+    # Command-R
+    "transformers.models.cohere.modeling_cohere.CohereModel": PolicyLocation(
+        file_name="command", class_name="CommandModelPolicy"
+    ),
+    "transformers.models.cohere.modeling_cohere.CohereForCausalLM": PolicyLocation(
+        file_name="command", class_name="CommandForCausalLMPolicy"
+    ),
 }

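The registry above maps a model's fully qualified class name to the policy module and class that shard it. The sketch below reproduces only the lookup pattern; `PolicyLocation` is modelled after the real container with the `file_name`/`class_name` fields shown in the diff, and `locate_policy` is a hypothetical helper, not the ColossalAI API.

```python
# Hedged sketch of the _POLICY_LIST lookup pattern shown above.
from dataclasses import dataclass


@dataclass
class PolicyLocation:
    file_name: str
    class_name: str


_POLICY_LIST = {
    "transformers.models.cohere.modeling_cohere.CohereForCausalLM": PolicyLocation(
        file_name="command", class_name="CommandForCausalLMPolicy"
    ),
}


def locate_policy(model) -> PolicyLocation:
    # key by the model's fully qualified class name, as the registry does
    qualified_name = f"{model.__class__.__module__}.{model.__class__.__qualname__}"
    return _POLICY_LIST[qualified_name]
```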
@@ -67,7 +67,7 @@ class BertPolicy(Policy):
         else:
             norm_cls = col_nn.LayerNorm

-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
         assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for Bert"
         if sp_mode == "ring":
             warnings.warn(
@@ -50,7 +50,7 @@ class BloomPolicy(Policy):
         else:
             norm_cls = col_nn.LayerNorm

-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
         assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for BLOOM"
         if sp_mode == "ring":
             warnings.warn(
@@ -57,7 +57,7 @@ class ChatGLMPolicy(Policy):
         else:
             norm_cls = col_nn.LayerNorm

-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
         assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for ChatGLM2"
         if sp_mode == "ring":
             warnings.warn(

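The policy hunks above replace the `enable_sequence_parallelism` gate with a plain `or None`. The two expressions only agree when `sequence_parallelism_mode` is cleared whenever the feature is off; the hedged sketch below uses a stub config to show where they can diverge, which is worth keeping in mind when reading these changes.

```python
# Hedged sketch of the two expressions being swapped in the policies above,
# using a stub config rather than the real ShardConfig.
class ShardConfigStub:
    enable_sequence_parallelism = False
    sequence_parallelism_mode = "split_gather"  # stale mode left behind


cfg = ShardConfigStub()
old_style = cfg.sequence_parallelism_mode if cfg.enable_sequence_parallelism else None
new_style = cfg.sequence_parallelism_mode or None
print(old_style, new_style)  # None split_gather
```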
@ -0,0 +1,369 @@
|
||||||
|
import warnings
|
||||||
|
from functools import partial
|
||||||
|
from typing import Callable, Dict, List, Union
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn import Module
|
||||||
|
|
||||||
|
from colossalai.shardformer.layer import (
|
||||||
|
FusedLayerNorm,
|
||||||
|
LayerNorm,
|
||||||
|
Linear1D_Col,
|
||||||
|
Linear1D_Row,
|
||||||
|
PaddingEmbedding,
|
||||||
|
PaddingLMHead,
|
||||||
|
VocabParallelEmbedding1D,
|
||||||
|
VocabParallelLMHead1D,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..modeling.command import (
|
||||||
|
CommandPipelineForwards,
|
||||||
|
get_command_flash_attention_forward,
|
||||||
|
get_command_flash_attention_model_forward,
|
||||||
|
get_lm_forward_with_dist_cross_entropy,
|
||||||
|
)
|
||||||
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
|
__all__ = ["CommandPolicy", "CommandForCausalLMPolicy"]
|
||||||
|
|
||||||
|
|
||||||
|
class CommandPolicy(Policy):
|
||||||
|
def config_sanity_check(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def preprocess(self):
|
||||||
|
self.tie_weight = self.tie_weight_check()
|
||||||
|
self.origin_attn_implement = self.model.config._attn_implementation
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||||
|
from transformers.models.cohere.modeling_cohere import (
|
||||||
|
CohereAttention,
|
||||||
|
CohereDecoderLayer,
|
||||||
|
CohereFlashAttention2,
|
||||||
|
CohereModel,
|
||||||
|
CohereSdpaAttention,
|
||||||
|
)
|
||||||
|
|
||||||
|
ATTN_IMPLEMENTATION = {
|
||||||
|
"eager": CohereAttention,
|
||||||
|
"flash_attention_2": CohereFlashAttention2,
|
||||||
|
"sdpa": CohereSdpaAttention,
|
||||||
|
}
|
||||||
|
policy = {}
|
||||||
|
|
||||||
|
attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
|
||||||
|
embedding_cls = None
|
||||||
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
|
embedding_cls = VocabParallelEmbedding1D
|
||||||
|
else:
|
||||||
|
if self.tie_weight:
|
||||||
|
embedding_cls = PaddingEmbedding
|
||||||
|
|
||||||
|
if self.shard_config.enable_fused_normalization:
|
||||||
|
norm_cls = FusedLayerNorm
|
||||||
|
else:
|
||||||
|
norm_cls = LayerNorm
|
||||||
|
|
||||||
|
if self.pipeline_stage_manager is not None:
|
||||||
|
self.shard_config.enable_sequence_parallelism = False
|
||||||
|
self.shard_config.enable_sequence_overlap = False
|
||||||
|
self.shard_config.sequence_parallelism_mode = None
|
||||||
|
warnings.warn(
|
||||||
|
f"For Command, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
|
||||||
|
)
|
||||||
|
sp_mode = self.shard_config.sequence_parallelism_mode or None
|
||||||
|
sp_size = self.shard_config.sequence_parallel_size or None
|
||||||
|
sp_group = self.shard_config.sequence_parallel_process_group or None
|
||||||
|
sp_partial_derived = sp_mode in ["split_gather", "ring"]
|
||||||
|
|
||||||
|
if sp_mode == "all_to_all":
|
||||||
|
decoder_attribute_replacement = {
|
||||||
|
"num_heads": self.model.config.num_attention_heads // sp_size,
|
||||||
|
}
|
||||||
|
if getattr(self.model.config, "num_key_value_heads", False):
|
||||||
|
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
|
||||||
|
|
||||||
|
policy[attn_cls] = ModulePolicyDescription(
|
||||||
|
attribute_replacement=decoder_attribute_replacement,
|
||||||
|
)
|
||||||
|
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
|
||||||
|
self.append_or_create_method_replacement(
|
||||||
|
description={
|
||||||
|
"forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
|
||||||
|
},
|
||||||
|
policy=policy,
|
||||||
|
target_key=attn_cls,
|
||||||
|
)
|
||||||
|
if self.pipeline_stage_manager is None:
|
||||||
|
self.append_or_create_method_replacement(
|
||||||
|
description={
|
||||||
|
"forward": get_command_flash_attention_model_forward(
|
||||||
|
self.shard_config,
|
||||||
|
sp_mode=sp_mode,
|
||||||
|
sp_size=sp_size,
|
||||||
|
sp_group=sp_group,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
policy=policy,
|
||||||
|
target_key=CohereModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
|
assert (
|
||||||
|
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
||||||
|
), f"The number of attention heads must be divisible by tensor parallel size."
|
||||||
|
if hasattr(self.model.config, "num_key_value_heads"):
|
||||||
|
assert (
|
||||||
|
self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
|
||||||
|
and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
|
||||||
|
), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
|
||||||
|
decoder_attribute_replacement = {
|
||||||
|
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||||
|
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||||
|
}
|
||||||
|
if getattr(self.model.config, "num_key_value_heads", False):
|
||||||
|
decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
|
||||||
|
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
|
||||||
|
)
|
||||||
|
|
||||||
|
policy[CohereDecoderLayer] = ModulePolicyDescription(
|
||||||
|
attribute_replacement=decoder_attribute_replacement,
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.q_proj",
|
||||||
|
target_module=Linear1D_Col,
|
||||||
|
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.k_proj",
|
||||||
|
target_module=Linear1D_Col,
|
||||||
|
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.v_proj",
|
||||||
|
target_module=Linear1D_Col,
|
||||||
|
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.o_proj",
|
||||||
|
target_module=Linear1D_Row,
|
||||||
|
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="mlp.gate_proj",
|
||||||
|
target_module=Linear1D_Col,
|
||||||
|
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="mlp.up_proj",
|
||||||
|
target_module=Linear1D_Col,
|
||||||
|
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="mlp.down_proj",
|
||||||
|
target_module=Linear1D_Row,
|
||||||
|
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
if embedding_cls is not None:
|
||||||
|
self.append_or_create_submodule_replacement(
|
||||||
|
description=SubModuleReplacementDescription(
|
||||||
|
suffix="embed_tokens",
|
||||||
|
target_module=embedding_cls,
|
||||||
|
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||||
|
),
|
||||||
|
policy=policy,
|
||||||
|
target_key=CohereModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
# optimization configuration
|
||||||
|
self.append_or_create_submodule_replacement(
|
||||||
|
description=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="input_layernorm",
|
||||||
|
target_module=norm_cls,
|
||||||
|
kwargs={"sp_partial_derived": sp_partial_derived},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
policy=policy,
|
||||||
|
target_key=CohereDecoderLayer,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.append_or_create_submodule_replacement(
|
||||||
|
description=SubModuleReplacementDescription(
|
||||||
|
suffix="norm",
|
||||||
|
target_module=norm_cls,
|
||||||
|
kwargs={"sp_partial_derived": sp_partial_derived},
|
||||||
|
),
|
||||||
|
policy=policy,
|
||||||
|
target_key=CohereModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
return policy
|
||||||
|
|
||||||
|
def postprocess(self):
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||||
|
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||||
|
to customized forward method, and add this changing to policy."""
|
||||||
|
if self.pipeline_stage_manager is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
stage_manager = self.pipeline_stage_manager
|
||||||
|
if self.model.__class__.__name__ == "CohereModel":
|
||||||
|
module = self.model
|
||||||
|
else:
|
||||||
|
module = self.model.model
|
||||||
|
|
||||||
|
if stage_manager.is_interleave:
|
||||||
|
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||||
|
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||||
|
method_replacement = {
|
||||||
|
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
|
||||||
|
}
|
||||||
|
|
||||||
|
else:
|
||||||
|
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||||
|
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||||
|
method_replacement = {
|
||||||
|
"forward": partial(
|
||||||
|
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
|
||||||
|
|
||||||
|
def get_held_layers(self) -> List[Module]:
|
||||||
|
"""Get pipeline layers for current stage."""
|
||||||
|
assert self.pipeline_stage_manager is not None
|
||||||
|
|
||||||
|
if self.model.__class__.__name__ == "CohereModel":
|
||||||
|
module = self.model
|
||||||
|
else:
|
||||||
|
module = self.model.model
|
||||||
|
stage_manager = self.pipeline_stage_manager
|
||||||
|
|
||||||
|
held_layers = []
|
||||||
|
if stage_manager.is_interleave:
|
||||||
|
assert stage_manager.num_model_chunks is not None
|
||||||
|
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||||
|
stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||||
|
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
|
held_layers.append(module.embed_tokens)
|
||||||
|
for start_idx, end_idx in stage_indices:
|
||||||
|
held_layers.extend(module.layers[start_idx:end_idx])
|
||||||
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
|
held_layers.append(module.norm)
|
||||||
|
|
||||||
|
else:
|
||||||
|
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||||
|
if stage_manager.is_first_stage():
|
||||||
|
held_layers.append(module.embed_tokens)
|
||||||
|
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||||
|
held_layers.extend(module.layers[start_idx:end_idx])
|
||||||
|
if stage_manager.is_last_stage():
|
||||||
|
held_layers.append(module.norm)
|
||||||
|
|
||||||
|
return held_layers
|
||||||
|
|
||||||
|
|
||||||
|
class CommandModelPolicy(CommandPolicy):
|
||||||
|
def module_policy(self):
|
||||||
|
policy = super().module_policy()
|
||||||
|
from transformers.models.cohere.modeling_cohere import CohereModel
|
||||||
|
|
||||||
|
if self.pipeline_stage_manager:
|
||||||
|
# set None as default
|
||||||
|
self.set_pipeline_forward(
|
||||||
|
model_cls=CohereModel, new_forward=CommandPipelineForwards.command_model_forward, policy=policy
|
||||||
|
)
|
||||||
|
return policy
|
||||||
|
|
||||||
|
def get_held_layers(self) -> List[Module]:
|
||||||
|
"""Get pipeline layers for current stage."""
|
||||||
|
held_layers = super().get_held_layers()
|
||||||
|
return held_layers
|
||||||
|
|
||||||
|
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||||
|
"""No shared params in command model"""
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class CommandForCausalLMPolicy(CommandPolicy):
|
||||||
|
def module_policy(self):
|
||||||
|
from transformers import CohereForCausalLM
|
||||||
|
|
||||||
|
policy = super().module_policy()
|
||||||
|
|
||||||
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
|
# add a new item for casual lm
|
||||||
|
new_item = {
|
||||||
|
CohereForCausalLM: ModulePolicyDescription(
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="lm_head",
|
||||||
|
target_module=VocabParallelLMHead1D,
|
||||||
|
kwargs={
|
||||||
|
"gather_output": not self.shard_config.parallel_output,
|
||||||
|
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if self.shard_config.parallel_output:
|
||||||
|
new_item[CohereForCausalLM].method_replacement = {
|
||||||
|
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
new_item = {
|
||||||
|
CohereForCausalLM: ModulePolicyDescription(
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="lm_head",
|
||||||
|
target_module=PaddingLMHead,
|
||||||
|
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
policy.update(new_item)
|
||||||
|
|
||||||
|
if self.pipeline_stage_manager:
|
||||||
|
# set None as default
|
||||||
|
self.set_pipeline_forward(
|
||||||
|
model_cls=CohereForCausalLM,
|
||||||
|
new_forward=CommandPipelineForwards.command_for_causal_lm_forward,
|
||||||
|
policy=policy,
|
||||||
|
)
|
||||||
|
|
||||||
|
return policy
|
||||||
|
|
||||||
|
def get_held_layers(self) -> List[Module]:
|
||||||
|
"""Get pipeline layers for current stage."""
|
||||||
|
stage_manager = self.pipeline_stage_manager
|
||||||
|
held_layers = super().get_held_layers()
|
||||||
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
|
held_layers.append(self.model.lm_head)
|
||||||
|
return held_layers
|
||||||
|
|
||||||
|
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||||
|
command_model = self.model.model
|
||||||
|
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||||
|
if (
|
||||||
|
id(command_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||||
|
and self.pipeline_stage_manager.num_stages > 1
|
||||||
|
):
|
||||||
|
# tie weights
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
0: command_model.embed_tokens.weight,
|
||||||
|
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
return []
|
|
@@ -65,7 +65,7 @@ class GPT2Policy(Policy):
         else:
             norm_cls = col_nn.LayerNorm

-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
         assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for GPT2"
         if sp_mode == "ring":
             warnings.warn(

@@ -34,15 +34,11 @@ class GPTJPolicy(Policy):
         return self.model

     def module_policy(self):
-        from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel
+        from transformers.models.gptj.modeling_gptj import GPTJ_ATTENTION_CLASSES, GPTJBlock, GPTJModel

-        ATTN_IMPLEMENTATION = {
-            "eager": GPTJAttention,
-        }

         policy = {}

-        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+        attn_cls = GPTJ_ATTENTION_CLASSES[self.origin_attn_implement]

         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:

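GPTJPolicy now resolves the attention class from the upstream `GPTJ_ATTENTION_CLASSES` registry keyed by the model's attention implementation instead of a local eager-only map. The sketch below shows the same lookup pattern with stub classes; it is an illustration, not the transformers registry itself.

```python
# Hedged sketch of the attention-class lookup pattern used above, with stub classes
# standing in for transformers' GPTJ_ATTENTION_CLASSES.
class EagerAttention: ...
class SdpaAttention: ...
class FlashAttention2: ...


ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "sdpa": SdpaAttention,
    "flash_attention_2": FlashAttention2,
}


def resolve_attn_cls(attn_implementation: str):
    # the policy keeps the implementation string recorded at preprocess time
    return ATTENTION_CLASSES[attn_implementation]


print(resolve_attn_cls("sdpa").__name__)  # SdpaAttention
```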
@@ -20,9 +20,7 @@ from colossalai.shardformer.layer import (
 from ..modeling.llama import (
     LlamaPipelineForwards,
     get_llama_flash_attention_forward,
-    get_llama_model_forward_for_flash_attn,
+    get_llama_flash_attention_model_forward,
-    get_llama_seq_parallel_attention_forward,
-    get_llama_seq_parallel_model_forward,
     get_lm_forward_with_dist_cross_entropy,
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -75,40 +73,12 @@ class LlamaPolicy(Policy):
             warnings.warn(
                 f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
             )
-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
-        sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None
+        sp_size = self.shard_config.sequence_parallel_size or None
-        sp_group = (
-            self.shard_config.sequence_parallel_process_group if self.shard_config.enable_sequence_parallelism else None
-        )
+        sp_group = self.shard_config.sequence_parallel_process_group or None
         sp_partial_derived = sp_mode in ["split_gather", "ring"]

-        use_flash_attention = self.shard_config.enable_flash_attention
-        # Currently sp cannot to be used with flashattention
-        if sp_mode in ["split_gather", "ring", "all_to_all"]:
-            if use_flash_attention:
-                warnings.warn(
-                    f"Sequence parallelism mode {sp_mode} need to be used with FlashAttention, will disable FlashAttention automatically."
-                )
-                use_flash_attention = False

-        if sp_mode in ["split_gather", "ring"]:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_llama_seq_parallel_model_forward(
-                        sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group
-                    ),
-                },
-                policy=policy,
-                target_key=LlamaModel,
-            )
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
-                },
-                policy=policy,
-                target_key=attn_cls,
-            )
-        elif sp_mode == "all_to_all":
+        if sp_mode == "all_to_all":
             decoder_attribute_replacement = {
                 "num_heads": self.model.config.num_attention_heads // sp_size,
             }
@@ -118,16 +88,19 @@ class LlamaPolicy(Policy):
             policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )
+        if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
+                    "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
                 target_key=attn_cls,
             )
+            if self.pipeline_stage_manager is None:
                 self.append_or_create_method_replacement(
                     description={
-                        "forward": get_llama_seq_parallel_model_forward(
+                        "forward": get_llama_flash_attention_model_forward(
+                            self.shard_config,
                             sp_mode=sp_mode,
                             sp_size=sp_size,
                             sp_group=sp_group,
@@ -235,25 +208,6 @@ class LlamaPolicy(Policy):
             target_key=LlamaModel,
         )

-        # use flash attention
-        if use_flash_attention:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
-                },
-                policy=policy,
-                target_key=attn_cls,
-            )
-            if self.pipeline_stage_manager is None:
-                # replace llama model forward method
-                self.append_or_create_method_replacement(
-                    description={
-                        "forward": get_llama_model_forward_for_flash_attn(self.shard_config),
-                    },
-                    policy=policy,
-                    target_key=LlamaModel,
-                )

         return policy

     def postprocess(self):

@@ -42,11 +42,13 @@ class MistralPolicy(Policy):
             MistralDecoderLayer,
             MistralFlashAttention2,
             MistralModel,
+            MistralSdpaAttention,
         )

         ATTN_IMPLEMENTATION = {
             "eager": MistralAttention,
             "flash_attention_2": MistralFlashAttention2,
+            "sdpa": MistralSdpaAttention,
         }

         policy = {}

@@ -25,6 +25,7 @@ class ChunkManager:
         chunk_configuration,
         init_device: Optional[torch.device] = None,
         reuse_fp16_chunk: bool = True,
+        max_prefetch: int = 0,
     ) -> None:
         self.device = init_device or get_accelerator().get_current_device()
         self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
@@ -42,6 +43,7 @@ class ChunkManager:
         # Whether model is accumulating gradients,
         self.accumulating_grads = False
         self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
+        self._prefetch_stream = get_accelerator().Stream() if max_prefetch else None

     def register_tensor(
         self,

@@ -21,6 +21,7 @@ def init_chunk_manager(
     hidden_dim: Optional[int] = None,
     reuse_fp16_chunk: bool = True,
     verbose: bool = False,
+    max_prefetch: int = 0,
     **kwargs,
 ) -> ChunkManager:
     if hidden_dim:
@@ -51,9 +52,5 @@ def init_chunk_manager(
     )
     dist.barrier()

-    chunk_manager = ChunkManager(
-        config_dict,
-        init_device,
-        reuse_fp16_chunk=reuse_fp16_chunk,
-    )
+    chunk_manager = ChunkManager(config_dict, init_device, reuse_fp16_chunk=reuse_fp16_chunk, max_prefetch=max_prefetch)
     return chunk_manager

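The two hunks above thread a new `max_prefetch` argument from `init_chunk_manager` into `ChunkManager`, which only creates a side stream when prefetching is requested. A hedged sketch of that gating follows; `torch.cuda.Stream` is used here in place of `get_accelerator().Stream()`, and the class is a stand-in, not the real ChunkManager.

```python
# Hedged sketch of the max_prefetch gating added to ChunkManager.__init__.
import torch


class TinyChunkManager:
    def __init__(self, max_prefetch: int = 0) -> None:
        self.max_prefetch = max_prefetch
        # create the prefetch stream only when prefetching is enabled and CUDA exists
        use_stream = max_prefetch > 0 and torch.cuda.is_available()
        self._prefetch_stream = torch.cuda.Stream() if use_stream else None


mgr = TinyChunkManager(max_prefetch=2)
print(mgr._prefetch_stream is not None)  # True on a CUDA machine, False otherwise
```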
@@ -104,9 +104,7 @@ class GeminiDDP(ModelWrapper):
         self.enable_gradient_accumulation = enable_gradient_accumulation
         if chunk_config_dict is not None:
             self.chunk_manager = ChunkManager(
-                chunk_config_dict,
-                chunk_init_device,
-                reuse_fp16_chunk=reuse_fp16_chunk,
+                chunk_config_dict, chunk_init_device, reuse_fp16_chunk=reuse_fp16_chunk, max_prefetch=max_prefetch
             )
         else:
             # some ugly hotfix for the compatibility with Lightning
@@ -122,6 +120,7 @@ class GeminiDDP(ModelWrapper):
                 process_group=zero_group,
                 reuse_fp16_chunk=reuse_fp16_chunk,
                 verbose=verbose,
+                max_prefetch=max_prefetch,
             )
         self.gemini_manager = GeminiManager(
             placement_policy,
@@ -451,6 +450,7 @@ class GeminiDDP(ModelWrapper):
             chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
             if not (master_weights) or (enable_gradient_accumulation):
                 chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+        return empty_grad

     def zero_grad(self, set_to_none: bool = False) -> None:
         self.module.zero_grad(set_to_none=True)

@@ -5,6 +5,7 @@ from typing import List

 import torch

+from colossalai.accelerator import get_accelerator
 from colossalai.tensor.param_op_hook import ColoParamOpHook
 from colossalai.utils import is_ddp_ignored
 from colossalai.zero.gemini import TensorState
@@ -54,6 +55,16 @@ class GeminiZeROHook(ColoParamOpHook):
         )

         # prefetch
+        if self._gemini_manager.chunk_manager._prefetch_stream is not None:
+            # This is when prefetch happens the first time and there is no dist.Work to sync,
+            # there is possibility that the optimizer haven't finish computation on default stream,
+            # thus we might prefetch outdated chunks there.
+            #
+            # Other than that, self._gemini_manager.wait_chunks will have synced with default stream
+            # by calling dist.Work.wait() and this line makes no diff.
+            self._gemini_manager.chunk_manager._prefetch_stream.wait_stream(torch.cuda.current_stream())

+        with get_accelerator().stream(self._gemini_manager.chunk_manager._prefetch_stream):
             for chunk in chunks_fetch_async:
                 maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
                 if maybe_work is not None:

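The hook change above first makes the prefetch stream wait on the default stream, then launches asynchronous chunk access inside that stream, so a prefetch never reads chunks the optimizer is still writing. The standalone sketch below reproduces only that stream-synchronization pattern with plain PyTorch; the tensors and the CPU copy are stand-ins for `access_chunk(..., async_access=True)`.

```python
# Hedged sketch of the prefetch-stream synchronization pattern used above.
# Runs only on a CUDA build; otherwise it simply does nothing.
import torch

if torch.cuda.is_available():
    prefetch_stream = torch.cuda.Stream()
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x  # pretend this is optimizer work queued on the default stream

    # do not start prefetching until the default stream's pending work is visible
    prefetch_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(prefetch_stream):
        prefetched = y.to("cpu", non_blocking=True)  # stand-in for an async chunk fetch
    # make later default-stream work wait for the prefetch before using it
    torch.cuda.current_stream().wait_stream(prefetch_stream)
    print(prefetched.shape)
```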
@@ -24,6 +24,7 @@
</div>

## 新闻
+* [2024/06] [Open-Sora Continues Open Source: Generate Any 16-Second 720p HD Video with One Click, Model Weights Ready to Use](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
* [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference)
* [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
* [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series)

@@ -31,10 +32,6 @@
* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
* [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
* [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer)
-* [2024/01] [Construct Refined 13B Private Model With Just $5000 USD, Upgraded Colossal-AI Llama-2 Open Source](https://hpc-ai.com/blog/colossal-llama-2-13b)
-* [2023/11] [Enhanced MoE Parallelism, Open-source MoE Model Training Can Be 9 Times More Efficient](https://www.hpc-ai.tech/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient)
-* [2023/09] [One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific LLM Solution](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution)
-* [2023/09] [70 Billion Parameter LLaMA2 Model Training Accelerated by 195%](https://www.hpc-ai.tech/blog/70b-llama2-training)
* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)

## 目录

@@ -127,13 +124,13 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
[Open-Sora](https://github.com/hpcaitech/Open-Sora):全面开源类Sora模型参数和所有训练细节
[[代码]](https://github.com/hpcaitech/Open-Sora)
-[[博客]](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
+[[博客]](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
-[[模型权重]](https://huggingface.co/hpcai-tech/Open-Sora)
+[[模型权重]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#model-weights)
[[演示样例]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo)

<div align="center">
-<a href="https://www.bilibili.com/video/BV1dW421c7MN">
+<a href="https://www.bilibili.com/video/BV1Fm421G7bV">
-<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/sora/sora-demo-cn.png" width="700" />
+<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/sora/opensora-v1.2.png" width="700" />
</a>
</div>

@@ -56,6 +56,7 @@
        "features/pipeline_parallel",
        "features/nvme_offload",
        "features/lazy_init",
+        "features/distributed_optimizers",
        "features/cluster_utils"
      ]
    },

@@ -14,12 +14,6 @@ Apart from the widely adopted Adam and SGD, many modern optimizers require layer
## Optimizers
Adafactor is a first-order Adam variant using Non-negative Matrix Factorization (NMF) to reduce memory footprint. CAME improves on it by introducing a confidence matrix to correct NMF. GaLore further reduces memory by projecting gradients into a low-rank space and applying 8-bit block-wise quantization. Lamb allows huge batch sizes without losing accuracy via layer-wise adaptive updates bounded by the inverse of its Lipschitz constant.

-## API Reference
-
-{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}
-{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }}
-{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }}
-{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}
-
## Hands-On Practice
We now demonstrate how to use Distributed Adafactor with the booster API, combining Tensor Parallel and ZeRO 2 on 4 GPUs. **Note that even if you're not aware of distributed optimizers, the plugins automatically cast yours to the distributed version for convenience.**

@@ -140,3 +134,10 @@ optim = DistGaloreAwamW(
</table>

<!-- doc-test-command: colossalai run --nproc_per_node 4 distributed_optimizers.py -->
+
+## API Reference
+
+{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}
+{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }}
+{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }}
+{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}

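The hands-on section referenced above pairs a distributed optimizer with the booster API. A minimal sketch of that pattern, assuming a toy `nn.Linear` model and placeholder sizes (only `DistributedAdaFactor`, `Booster`, and `HybridParallelPlugin` are taken from the docs; everything else is illustrative):

```python
# Hedged sketch: boost a model and a distributed optimizer with a TP + ZeRO-2 plugin.
# Assumes the script is started with `colossalai run` / torchrun so launch_from_torch works.
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor

colossalai.launch_from_torch()

model = nn.Linear(1024, 1024).cuda()          # toy model, placeholder size
optim = DistributedAdaFactor(model.parameters())

plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=2, precision="bf16")
booster = Booster(plugin=plugin)
model, optim, *_ = booster.boost(model, optim)  # plugin shards states across ranks
```

As the note says, passing the plain (non-distributed) optimizer should also work, since the plugin casts it to the distributed variant during `boost`.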
@@ -13,12 +13,6 @@ Author: Wenxuan Tan, Junwen Duan, Renjie Mao
## 优化器
Adafactor 是一种首次采用非负矩阵分解(NMF)的 Adam 变体,用于减少内存占用。CAME 通过引入一个置信度矩阵来改进 NMF 的效果。GaLore 通过将梯度投影到低秩空间,并使用 8 位块状量化进一步减少内存占用。Lamb 允许使用巨大的批量大小而不失准确性,通过按其 Lipschitz 常数的倒数界定的逐层自适应更新实现。

-## API 参考
-
-{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}
-{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }}
-{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }}
-{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}
-
## 使用
现在我们展示如何使用分布式 Adafactor 与 booster API 结合 Tensor Parallel 和 ZeRO 2。即使您不使用 distributed optimizer,plugin 也会自动将 optimizer 转换为分布式版本以方便使用。

@@ -137,3 +131,10 @@ optim = DistGaloreAwamW(
</table>

<!-- doc-test-command: colossalai run --nproc_per_node 4 distributed_optimizers.py -->
+
+## API 参考
+
+{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}
+{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }}
+{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }}
+{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}

@@ -43,5 +43,8 @@ colossalai run --nproc_per_node 2 llama_generation.py -m PATH_MODEL --drafter_mo

If you want to try the GLIDE model (glide-vicuna7b) as the drafter model with vicuna-7B, you could provide the GLIDE model path or model card as the drafter model and enable the feature by
```python
+from colossalai.inference.modeling.models.glide_llama import GlideLlamaForCausalLM
+drafter_model = GlideLlamaForCausalLM.from_pretrained(drafter_model_path_or_name)
+...
engine.enable_spec_dec(drafter_model, use_glide_drafter=True)
```

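For context, a hedged sketch of the surrounding engine setup that the snippet above elides. Model paths are placeholders; `InferenceConfig`, `InferenceEngine`, `add_request`, `generate`, and `enable_spec_dec` are the APIs already named in this changeset, the rest is assumed:

```python
# Sketch only: speculative decoding with a GLIDE drafter around a vicuna-7B main model.
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.models.glide_llama import GlideLlamaForCausalLM

model_path = "PATH_MODEL"                      # placeholder: local vicuna-7B weights
drafter_model_path_or_name = "PATH_GLIDE"      # placeholder: GLIDE drafter weights or model card

model = AutoModelForCausalLM.from_pretrained(model_path).half().cuda()
tokenizer = AutoTokenizer.from_pretrained(model_path)

engine = InferenceEngine(model, tokenizer, InferenceConfig(max_output_len=128), verbose=True)

drafter_model = GlideLlamaForCausalLM.from_pretrained(drafter_model_path_or_name)
engine.enable_spec_dec(drafter_model, use_glide_drafter=True)

engine.add_request(prompts=["Introduce some landmarks in Paris:"])
outputs = engine.generate()
```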
@@ -72,6 +72,7 @@ def main():
    parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
    parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
    parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size")
    parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
    parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
    parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")

@@ -174,6 +175,8 @@ def main():
            tp_size=args.tp,
            pp_size=args.pp,
            zero_stage=args.zero,
+            sp_size=args.sp,
+            enable_sequence_parallelism=args.sp > 1,
            enable_fused_normalization=torch.cuda.is_available(),
            enable_flash_attention=args.xformers,
            microbatch_size=args.mbs,

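A short sketch of how the new `--sp` flag is meant to feed the plugin (argument names mirror the hunk above; the remaining parser and plugin arguments are omitted, and a launched distributed context is assumed before the plugin is constructed):

```python
# Hedged sketch: wire a sequence-parallel size from the CLI into HybridParallelPlugin.
import argparse

from colossalai.booster.plugin import HybridParallelPlugin

parser = argparse.ArgumentParser()
parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size")
args = parser.parse_args()

plugin = HybridParallelPlugin(
    tp_size=args.tp,
    pp_size=args.pp,
    sp_size=args.sp,
    enable_sequence_parallelism=args.sp > 1,  # only enable SP when more than one rank shares a sequence
)
```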
@@ -16,7 +16,7 @@ ray
sentencepiece
google
protobuf
-transformers>=4.36.2,<4.40.0
+transformers==4.39.3
peft>=0.7.1
bitsandbytes>=0.39.0
rpyc==6.0.0

@@ -22,3 +22,9 @@ try:
    from .qwen2 import *
except ImportError:
    print("This version of transformers doesn't support qwen2.")
+
+
+try:
+    from .command import *
+except ImportError:
+    print("This version of transformers doesn't support Command-R.")

@@ -33,22 +33,6 @@ loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(
)
loss_fn = lambda x: x["loss"]

-config = AutoConfig.from_pretrained(
-    "THUDM/chatglm2-6b",
-    trust_remote_code=True,
-    num_layers=2,
-    padded_vocab_size=65024,
-    hidden_size=64,
-    ffn_hidden_size=214,
-    num_attention_heads=8,
-    kv_channels=16,
-    rmsnorm=True,
-    original_rope=True,
-    use_cache=True,
-    multi_query_attention=False,
-    torch_dtype=torch.float32,
-)
-
infer_config = AutoConfig.from_pretrained(
    "THUDM/chatglm2-6b",

@@ -68,6 +52,21 @@ infer_config = AutoConfig.from_pretrained(


def init_chatglm():
+    config = AutoConfig.from_pretrained(
+        "THUDM/chatglm2-6b",
+        trust_remote_code=True,
+        num_layers=2,
+        padded_vocab_size=65024,
+        hidden_size=64,
+        ffn_hidden_size=214,
+        num_attention_heads=8,
+        kv_channels=16,
+        rmsnorm=True,
+        original_rope=True,
+        use_cache=True,
+        multi_query_attention=False,
+        torch_dtype=torch.float32,
+    )
    model = AutoModelForCausalLM.from_config(config, empty_init=False, trust_remote_code=True)
    for m in model.modules():
        if m.__class__.__name__ == "RMSNorm":

@@ -0,0 +1,79 @@
+import torch
+import transformers
+
+from ..registry import ModelAttribute, model_zoo
+
+try:
+    from transformers import CohereConfig
+
+    HAS_COMMAND = True
+except ImportError:
+    HAS_COMMAND = False
+
+if HAS_COMMAND:
+    # ===============================
+    # Register Command-R
+    # ===============================
+
+    def data_gen():
+        input_ids = torch.Tensor(
+            [
+                [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
+                [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
+            ]
+        ).long()
+
+        attention_mask = torch.Tensor(
+            [
+                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+            ]
+        ).long()
+
+        return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+    # label is needed for casual lm
+    def data_gen_for_casual_lm():
+        data = data_gen()
+        labels = data["input_ids"].clone()
+        data["labels"] = labels
+        return data
+
+    # transform the output to a dict
+    output_transform_fn = lambda x: x
+
+    # function to get the loss
+    loss_fn = lambda output: output["last_hidden_state"].mean()
+    loss_fn_for_casual_lm = lambda output: output["loss"]
+    loss_fn_for_seq_classification = lambda output: output["logits"].mean()
+
+    config = CohereConfig(
+        num_hidden_layers=8,
+        hidden_size=32,
+        intermediate_size=64,
+        num_attention_heads=4,
+        max_position_embeddings=128,
+    )
+
+    if hasattr(config, "pad_token_id"):
+        config.pad_token_id = config.eos_token_id
+
+    # register the following models
+    # transformers.CohereModel,
+    # transformers.CohereForCausalLM,
+    model_zoo.register(
+        name="transformers_command",
+        model_fn=lambda: transformers.CohereModel(config),
+        data_gen_fn=data_gen,
+        output_transform_fn=output_transform_fn,
+        loss_fn=loss_fn,
+        model_attribute=ModelAttribute(has_control_flow=True),
+    )
+    model_zoo.register(
+        name="transformers_command_for_casual_lm",
+        model_fn=lambda: transformers.CohereForCausalLM(config),
+        data_gen_fn=data_gen_for_casual_lm,
+        output_transform_fn=output_transform_fn,
+        loss_fn=loss_fn_for_casual_lm,
+        model_attribute=ModelAttribute(has_control_flow=True),
+    )

@@ -4,7 +4,7 @@ import numpy as np
import pytest
import torch

-from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
+from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask

@@ -26,7 +26,7 @@ def prepare_data(
    num_tokens = torch.sum(context_lengths).item()

    max_seq_len_in_batch = context_lengths.max()
-    cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
+    cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.int32), (1, 0))

    kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
    key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)

|
||||||
torch.manual_seed(10)
|
torch.manual_seed(10)
|
||||||
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
|
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
|
||||||
# our crafted op equals to Transformers
|
# our crafted op equals to Transformers
|
||||||
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
|
x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
|
||||||
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
|
x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
|
||||||
|
|
||||||
|
position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))
|
||||||
|
|
||||||
emb = LlamaRotaryEmbedding(D)
|
emb = LlamaRotaryEmbedding(D)
|
||||||
cos, sin = emb(x0, TOTAL_TOKENS)
|
|
||||||
|
cos, sin = emb(x0, position_ids)
|
||||||
|
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
|
||||||
|
cos = cos.reshape((TOTAL_TOKENS, -1))
|
||||||
|
sin = sin.reshape((TOTAL_TOKENS, -1))
|
||||||
cos_2 = cos[:, : D // 2]
|
cos_2 = cos[:, : D // 2]
|
||||||
sin_2 = sin[:, : D // 2]
|
sin_2 = sin[:, : D // 2]
|
||||||
position_ids = torch.arange(TOTAL_TOKENS)
|
x2 = x0.transpose(1, 2).reshape(TOTAL_TOKENS, H, D)
|
||||||
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
|
embd_stimulated_x = torch_rotary_emb(x2, cos_2, sin_2)
|
||||||
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
|
embd_stimulated_x = embd_stimulated_x.reshape((BATCH_SIZE, SEQ_LEN, H, D)).transpose(1, 2)
|
||||||
assert torch.allclose(embd_x0, embd_stimulated_x)
|
assert torch.allclose(embd_x0, embd_stimulated_x)
|
||||||
|
|
||||||
# create data
|
# create data
|
||||||
|
|
|
@@ -2,7 +2,7 @@ import pytest
import torch
from packaging import version

-from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
+from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.triton import context_attention_unpadded
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.kernel_utils import (

@@ -3,7 +3,7 @@ import pytest
import torch
from packaging import version

-from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
+from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.kernel_utils import (

@@ -43,15 +43,19 @@ def torch_rotary_emb(x, cos, sin):
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout):
    TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
    # our crafted op equals to Transformers
-    x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
+    x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
-    x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
+    x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
    emb = LlamaRotaryEmbedding(D)
-    cos, sin = emb(x0, TOTAL_TOKENS)
+    position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))
+    cos, sin = emb(x0, position_ids)
+    embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
+    cos = cos.reshape((TOTAL_TOKENS, -1))
+    sin = sin.reshape((TOTAL_TOKENS, -1))
    cos_2 = cos[:, :32]
    sin_2 = sin[:, :32]
-    position_ids = torch.arange(TOTAL_TOKENS)
-    embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
-    embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
+    x2 = x0.transpose(1, 2).reshape(TOTAL_TOKENS, H, D)
+    embd_stimulated_x = torch_rotary_emb(x2, cos_2, sin_2)
+    embd_stimulated_x = embd_stimulated_x.reshape((BATCH_SIZE, SEQ_LEN, H, D)).transpose(1, 2)
    assert torch.allclose(embd_x0, embd_stimulated_x)

    # create data

@@ -55,7 +55,7 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa
        assert inference_engine.generation_config.max_new_tokens == output_len
        inference_engine.add_request(prompts=inputs)
        assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=output_len)
        outputs = inference_engine.generate(generation_config=generation_config)
    else:
        if prompt_template:

@@ -0,0 +1,161 @@
+import os
+import random
+
+import numpy as np
+import pytest
+import torch
+import torch.distributed as dist
+from torch.multiprocessing import Manager
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, LlamaForCausalLM, LlamaTokenizer
+
+import colossalai
+import colossalai.inference.modeling.policy as policy
+from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+
+# NOTE: To test a model with the inference engine, you need to provide the path to your
+# local pretrained model weights in the MODEL_MAP dictionary
+MODEL_MAP = {
+    "baichuan": {
+        "model": AutoModelForCausalLM,
+        "tokenizer": AutoTokenizer,
+        "policy": policy.NoPaddingBaichuanModelInferPolicy,
+        "model_name_or_path": "baichuan-inc/Baichuan2-13B-Base",  # provide the path to local model weights
+    },
+    "llama": {
+        "model": LlamaForCausalLM,
+        "tokenizer": LlamaTokenizer,
+        "policy": policy.NoPaddingLlamaModelInferPolicy,
+        "model_name_or_path": "meta-llama/Llama-2-70b-hf",
+    },
+}
+
+MODELS_TO_TEST = ["llama", "baichuan"]  # Specify the models to test
+
+
+@parameterize("model", MODELS_TO_TEST)
+@parameterize("prompt_template", [None, "model_specific"])
+@parameterize("do_sample", [False])
+@parameterize("use_cuda_kernel", [True])
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+def test_model(model, prompt_template, do_sample, use_cuda_kernel):
+    model_path = MODEL_MAP[model]["model_name_or_path"]
+    if not os.path.exists(model_path):
+        pytest.skip(
+            f"There is no local model address included for {model}, please replace this address with a valid one."
+        )
+
+    if prompt_template == "model_specific":
+        prompt_template = model
+
+    model_config = MODEL_MAP[model]
+
+    kwargs1 = {
+        "model": model,
+        "use_engine": True,
+        "prompt_template": prompt_template,
+        "do_sample": do_sample,
+        "policy": model_config["policy"](),
+        "use_cuda_kernel": use_cuda_kernel,
+    }
+
+    kwargs2 = {
+        "model": model,
+        "use_engine": False,
+        "prompt_template": prompt_template,
+        "do_sample": do_sample,
+        "policy": None,
+        "use_cuda_kernel": use_cuda_kernel,
+    }
+
+    colossal_tp_1_output = run_engine(1, **kwargs1)
+    colossal_tp_2_output = run_engine(2, **kwargs1)
+    transformer_tp_1_output = run_engine(1, **kwargs2)
+
+    for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
+        assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
+        assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
+
+
+def run_engine(world_size, **kwargs):
+    manager = Manager()
+    result_list = manager.list([-1] * world_size)  # Create a shared list
+    spawn(run_dist, world_size, func_to_run=_run_engine, ret=result_list, **kwargs)
+    return result_list[0]
+
+
+def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
+
+    if ret:
+        ret[rank] = func_to_run(**kwargs)
+    else:
+        func_to_run(**kwargs)
+
+
+def _run_engine(model, use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None, policy=None):
+    setup_seed(20)
+    model_config = MODEL_MAP[model]
+    model_name_or_path = model_config["model_name_or_path"]
+    tokenizer = model_config["tokenizer"].from_pretrained(model_name_or_path, use_fast=False, trust_remote_code=True)
+    model = model_config["model"].from_pretrained(model_name_or_path, trust_remote_code=True).half().cuda()
+    model = model.eval()
+
+    inputs = [
+        "Introduce some landmarks in Paris:",
+    ]
+
+    output_len = 38
+
+    if do_sample:
+        top_p = 0.5
+        top_k = 50
+    else:
+        top_p = None
+        top_k = None
+
+    if use_engine:
+        inference_config = InferenceConfig(
+            max_output_len=output_len,
+            prompt_template=prompt_template,
+            use_cuda_kernel=use_cuda_kernel,
+            tp_size=dist.get_world_size(),
+        )
+        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
+        assert inference_engine.generation_config.max_new_tokens == output_len
+        inference_engine.add_request(prompts=inputs)
+        assert inference_engine.request_handler._has_waiting()
+        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=output_len)
+        outputs = inference_engine.generate(generation_config=generation_config)
+    else:
+        if prompt_template:
+            # apply prompt template
+            inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
+        inputs = inputs.cuda()
+        generation_config = GenerationConfig(
+            do_sample=do_sample,
+            top_p=top_p,
+            top_k=top_k,
+            pad_token_id=tokenizer.pad_token_id,
+            max_new_tokens=output_len,
+        )
+        outputs = model.generate(inputs, generation_config=generation_config)
+        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+    return outputs
+
+
+def setup_seed(seed):
+    torch.manual_seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+
+if __name__ == "__main__":
+    test_model()

@@ -0,0 +1,322 @@
+import os
+
+import pytest
+import torch
+import torch.distributed as dist
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer import PipelineGradientCheckpointConfig
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import (
+    build_model_from_hybrid_plugin,
+    check_all_grad_tensors,
+    check_loss,
+    check_output_hidden_state,
+    check_weight,
+    get_grad_tensors_for_check,
+    run_forward_backward_with_hybrid_plugin,
+    unwrap_model,
+)
+
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
+
+
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+    enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
+    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
+        model_fn, loss_fn, test_config
+    )
+    if enable_gradient_checkpointing:
+        # org_model.gradient_checkpointing_enable()
+        sharded_model.unwrap().gradient_checkpointing_enable()
+
+    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
+        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+    )
+
+    stage_manager = booster.plugin.stage_manager
+    tp_group = booster.plugin.tp_group
+
+    # unwrap model
+    command_model = unwrap_model(org_model, "CohereModel", "model")
+    shard_command_model = unwrap_model(sharded_model, "CohereModel", "model")
+
+    row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
+    col_layer_for_check = ["layers[0].self_attn.o_proj"]
+    # Here we check the grad of layernorm because an all-reduce operation should be performed during sequence parallelism
+    norm_layer_for_check = ["layers[0].input_layernorm", "layers[1].input_layernorm"]
+
+    # During pipeline parallelism, we cannot get the grad of norm layer during first stage, so we only check this when pp is not enbaled
+    if stage_manager is None:
+        norm_layer_for_check.append("norm")
+
+    # Check the grad when using ZeRO-1 and ZeRO-2
+    if (
+        booster.plugin.zero_stage in [1, 2]
+        and booster.plugin.shard_config.enable_sequence_parallelism
+        and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
+    ):
+        for p1, p2 in zip(command_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
+            working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
+            grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
+            grad_index = (
+                0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank
+            )
+            grad = grads[grad_index]
+            sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
+            assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
+
+    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+    grads_to_check = {}
+    if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
+        if test_config["precision"] == "fp32":
+            atol, rtol = 1e-6, 1e-4
+        else:
+            atol, rtol = 5e-3, 5e-3
+        row_layer_grads = get_grad_tensors_for_check(
+            command_model,
+            shard_command_model,
+            row_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=0,
+            verbose=False,
+        )
+        col_layer_grads = get_grad_tensors_for_check(
+            command_model,
+            shard_command_model,
+            col_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
+        )
+        norm_layer_grads = get_grad_tensors_for_check(
+            command_model,
+            shard_command_model,
+            norm_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
+        )
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)
+        grads_to_check.update(norm_layer_grads)
+
+    # optimizer executes step
+    org_optimizer.step()
+    sharded_optimizer.step()
+
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
+        if test_config["precision"] == "fp32":
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+
+        if org_model.__class__.__name__ == "CohereModel":
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    # check weights
+    if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
+        if test_config["precision"] == "fp32":
+            atol, rtol = 5e-4, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+        check_weight(
+            command_model,
+            shard_command_model,
+            col_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
+        )
+
+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
+    torch.cuda.empty_cache()
+
+
+@parameterize(
+    "test_config",
+    [
+        {
+            "tp_size": 2,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 2,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "sp_size": 2,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "use_lazy_init": True,
+            "zero_stage": 2,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "enable_all_optimization": True,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+            "enable_gradient_checkpointing": True,
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 4,
+            "use_lazy_init": False,
+            "precision": "fp32",
+            "enable_gradient_checkpointing": True,
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 1,
+            "enable_all_optimization": True,
+            "use_lazy_init": True,
+            "zero_stage": 2,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "enable_all_optimization": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+    ],
+)
+def run_command_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
+
+
+@parameterize(
+    "test_config",
+    [
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "use_lazy_init": False,
+            "precision": "fp32",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "use_lazy_init": False,
+            "precision": "fp16",
+            "zero_stage": 1,
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "pp_style": "interleaved",
+            "num_model_chunks": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "precision": "fp16",
+            "zero_stage": 1,
+            "initial_scale": 1,
+            "enable_gradient_checkpointing": True,
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
+                num_ckpt_layers_per_stage=[0, 1, 2, 2],
+            ),
+        },
+    ],
+)
+def run_command_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
+
+
+def check_command(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_command_test()
+
+
+def check_command_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_command_3d_test()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_command():
+    spawn(check_command, 4)
+
+
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_command_3d():
+    spawn(check_command_3d, 8)
+
+
+if __name__ == "__main__":
+    test_command()
+    test_command_3d()

@@ -120,9 +120,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            atol, rtol = 1e-4, 1e-3
        else:
            atol, rtol = 5e-3, 5e-3
+        try:
            check_weight(
-                llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
+                llama_model,
+                shard_llama_model,
+                col_layer_for_check,
+                tp_group,
+                atol=atol,
+                rtol=rtol,
+                dim=1,
+                verbose=False,
            )
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e

    # check grads
    check_all_grad_tensors(grads_to_check)

@@ -133,9 +144,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
    "test_config",
    [
-        {
+        {  # Test ring + Flash attention
            "tp_size": 2,
            "pp_size": 1,
+            "sp_size": 2,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "ring",

@@ -145,14 +157,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "precision": "fp16",
            "initial_scale": 1,
        },
-        {
+        {  # Ulysess + Flash attention
-            "tp_size": 4,
+            "tp_size": 1,
-            "pp_size": 1,
+            "pp_size": 2,
-            "num_microbatches": 1,
+            "sp_size": 2,
+            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "split_gather",
+            "sequence_parallelism_mode": "all_to_all",
-            "enable_flash_attention": False,
+            "enable_flash_attention": True,
            "use_lazy_init": True,
+            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },

@@ -164,7 +178,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "use_lazy_init": True,
-            "zero_stage": 2,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
            "precision": "fp16",
            "initial_scale": 1,
        },

@@ -213,7 +238,11 @@ def run_llama_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        try:
            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e

    clear_layout_converter()
    Randomizer.reset_index()

@@ -263,7 +292,11 @@ def run_llama_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        try:
            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e

    clear_layout_converter()
    Randomizer.reset_index()

@@ -217,6 +217,7 @@ def check_qwen2_3d(rank, world_size, port):


@pytest.mark.skipif(transformers.__version__ < "4.39.1", reason="Requires transformers version 4.39.1 or later")
+@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_qwen2():

@@ -224,6 +225,7 @@ def test_qwen2():


@pytest.mark.skipif(transformers.__version__ < "4.39.1", reason="Requires transformers version 4.39.1 or later")
+@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_qwen2_3d():

@@ -1 +1 @@
-0.3.8
+0.3.9