mirror of https://github.com/hpcaitech/ColossalAI
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into main
commit 4b59d874df
@@ -2,7 +2,7 @@ name: Build on PR

on:
  pull_request:
    types: [synchronize, opened, reopened, ready_for_review, closed, edited]
    types: [synchronize, opened, reopened, ready_for_review, closed]
    branches:
      - "main"
      - "develop"
README.md
@@ -25,6 +25,7 @@
</div>

## Latest News
* [2024/06] [Open-Sora Continues Open Source: Generate Any 16-Second 720p HD Video with One Click, Model Weights Ready to Use](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
* [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference)
* [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
* [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series)
@@ -32,10 +33,6 @@
* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
* [2024/03] [Open-Sora: Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
* [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer)
* [2024/01] [Construct Refined 13B Private Model With Just $5000 USD, Upgraded Colossal-AI Llama-2 Open Source](https://hpc-ai.com/blog/colossal-llama-2-13b)
* [2023/11] [Enhanced MoE Parallelism, Open-source MoE Model Training Can Be 9 Times More Efficient](https://www.hpc-ai.tech/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient)
* [2023/09] [One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific LLM Solution](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution)
* [2023/09] [70 Billion Parameter LLaMA2 Model Training Accelerated by 195%](https://www.hpc-ai.tech/blog/70b-llama2-training)
* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)

## Table of Contents
@@ -132,13 +129,13 @@ distributed training and inference in a few lines.

[Open-Sora](https://github.com/hpcaitech/Open-Sora): Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models
[[code]](https://github.com/hpcaitech/Open-Sora)
[[blog]](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
[[HuggingFace model weights]](https://huggingface.co/hpcai-tech/Open-Sora)
[[blog]](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
[[Model weights]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#model-weights)
[[Demo]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo)

<div align="center">
<a href="https://www.youtube.com/watch?v=iDTxepqixuc">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/sora/sora-demo.png" width="700" />
<a href="https://youtu.be/ilMQpU71ddI?si=J4JSPzZ03ycYmlki">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/sora/opensora-v1.2.png" width="700" />
</a>
</div>

@@ -999,7 +999,9 @@ class HybridParallelPlugin(PipelinePluginBase):
        ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"

        if enable_sequence_parallelism:
            self.sequence_parallelism_mode = sequence_parallelism_mode if sequence_parallelism_mode is not None else "1"
            self.sequence_parallelism_mode = (
                sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
            )
            assert (
                self.sequence_parallelism_mode in SUPPORT_SP_MODE
            ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
@@ -1014,19 +1016,13 @@ class HybridParallelPlugin(PipelinePluginBase):
                self.sp_size = 1
                self.dp_size = dist.get_world_size() // (tp_size * pp_size)
            elif self.sequence_parallelism_mode in ["all_to_all"]:
                assert (
                    tp_size == 1
                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with tensor parallelism"
                assert (
                    pp_size == 1
                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with pipeline parallelism"
                self.sp_size = dist.get_world_size() if sp_size is None else sp_size
                self.dp_size = dist.get_world_size() // (self.sp_size * pp_size)
                self.sp_size = 1 if sp_size is None else sp_size
                self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
        else:
            self.dp_size = dist.get_world_size() // (tp_size * pp_size)
            assert (
                sp_size == 1 or sp_size is None
            ), f"sp_size can only be set to a >1 number when enable_sequence_parallelism is True"
            ), f"You should not set sp_size when sequence parallelism is not enabled."
            self.sp_size = 1

        self.tp_size = tp_size
@@ -1040,11 +1036,22 @@ class HybridParallelPlugin(PipelinePluginBase):
        self.enable_jit_fused = enable_jit_fused
        self.enable_sequence_parallelism = enable_sequence_parallelism
        if dp_outside:
            self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
            (
                self.dp_axis,
                self.pp_axis,
                self.tp_axis,
                self.sp_axis,
            ) = (
                0,
                1,
                2,
                3,
            )
            self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
        else:
            self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
            self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)

        self.stage_manager = None
        self.schedule = None
        self.custom_policy = custom_policy
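A minimal usage sketch of what this hunk changes: with `enable_sequence_parallelism=True` and no explicit mode, the plugin now falls back to `"all_to_all"`, leaving `sp_size` unset now means `sp_size = 1`, and the data-parallel size is derived as `world_size // (sp_size * pp_size * tp_size)`. The concrete argument values below are illustrative assumptions, not taken from the diff, and constructor details may differ across versions.

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumed launch on 8 GPUs: tp=2, pp=1, sp=2 -> dp = 8 // (2 * 1 * 2) = 2
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    enable_sequence_parallelism=True,  # mode now defaults to "all_to_all"
    sp_size=2,                         # leaving this as None now means sp_size = 1
)
booster = Booster(plugin=plugin)
```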
@@ -314,7 +314,7 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
    """
    # Move all tensors in the state_dict to CPU before saving to avoid serialization issues
    state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)
    state_dict_cpu = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, state_dict)

    if use_safetensors:
        assert is_safetensors_available(), "safetensors is not available."
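A plausible reading of the one-line change above is that `x.data.cpu()` strips the `Parameter`/autograd wrapper before copying to CPU, so only plain tensors are serialized. The sketch below assumes `tree_map` comes from `torch.utils._pytree`; the checkpoint module may import it from elsewhere.

```python
import torch
from torch.utils._pytree import tree_map  # assumed source of tree_map

state_dict = {"weight": torch.nn.Parameter(torch.randn(2, 2)), "step": 3}

# .data.cpu() copies the underlying storage without the Parameter/autograd wrapper.
state_dict_cpu = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, state_dict)

print(type(state_dict_cpu["weight"]))  # <class 'torch.Tensor'>, no longer a Parameter
```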
@@ -236,7 +236,7 @@ Completion api is used for single sequence request, like answer a question or co
- POST '/chat':
    Chat api is used for conversation-style requests, which often include dialogue participants (i.e. roles) and their corresponding utterances. Because this input format differs greatly from plain completion inputs, we introduce a chat template to match the data format expected by chat models.
    #### chat-template
    Followed `transformers`, we add the chat-template argument. As chat models have been trained with very different formats for converting conversations into a single tokenizable string. Using a format that matches the training data is extremely important. This attribute(chat_template) is inclueded in HuggingFace tokenizers, containing a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace-blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example temlate bellow. Both str or file style chat template are supported.
    Following `transformers`, we add the chat-template argument. Chat models have been trained with very different formats for converting conversations into a single tokenizable string, and using a format that matches the training data is extremely important. This attribute (chat_template) is included in HuggingFace tokenizers and contains a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example template below (see the sketch right after this paragraph); both string and file style chat templates are supported.
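A minimal, hedged illustration of what such a chat template looks like and how it maps a conversation history to a single prompt string. The template and model id below are illustrative assumptions, not the project's shipped example; the same template text could equally be passed to the server via the chat-template argument as a string or a file path.

```python
from transformers import AutoTokenizer

# A deliberately simple Jinja chat template (illustrative only; real chat models
# ship their own template, which should normally be preferred).
SIMPLE_TEMPLATE = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
    "assistant:"
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")  # assumed model id
tokenizer.chat_template = SIMPLE_TEMPLATE

history = [
    {"role": "user", "content": "What is Colossal-AI?"},
    {"role": "assistant", "content": "A unified system for large-scale training and inference."},
    {"role": "user", "content": "How do I serve a chat model?"},
]
prompt = tokenizer.apply_chat_template(history, tokenize=False)
print(prompt)
```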
### Usage
#### Args for customizing your server
The configuration for the api server contains both the serving interface and the engine backend.
@@ -10,6 +10,7 @@ import torch
from transformers.generation import GenerationConfig

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.utils import can_use_flash_attn2

GibiByte = 1024**3

@@ -169,7 +170,8 @@ class InferenceConfig(RPC_PARAM):
        no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences.
        repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition. Defaults to 1.0.
        ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
        n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
        use_spec_dec (bool): Indicate whether to use speculative decoding, defaults to False.
        max_n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
        glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
        block_size (int): The number of blocks in a logical block, defaults to 16.
        tp_size (int): Tensor parallel size, defaults to 1.
@@ -214,6 +216,7 @@ class InferenceConfig(RPC_PARAM):
    ignore_eos: bool = False

    # speculative decoding configs
    use_spec_dec: bool = False
    max_n_spec_tokens: int = 5
    glimpse_large_kv: bool = False

@@ -311,6 +314,16 @@ class InferenceConfig(RPC_PARAM):

        return GenerationConfig.from_dict(meta_config)

    def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig":
        use_flash_attn = can_use_flash_attn2(self.dtype)
        model_inference_config = ModelShardInferenceConfig(
            dtype=self.dtype,
            use_cuda_kernel=self.use_cuda_kernel,
            use_spec_dec=self.use_spec_dec,
            use_flash_attn=use_flash_attn,
        )
        return model_inference_config

    def to_rpc_param(self) -> dict:
        kwargs = {
            "dtype": str(self.dtype).split(".")[-1],
@@ -362,3 +375,21 @@ class InferenceConfig(RPC_PARAM):
        # Set the attributes from the parsed arguments.
        inference_config = cls(**inference_config_args)
        return inference_config


@dataclass
class ModelShardInferenceConfig:
    """
    Configurations used during init of module for inference modeling.

    Args:
        dtype (torch.dtype): The data type for weights and activations.
        use_cuda_kernel (bool): Whether to use the CUDA kernels; faster, but may occasionally lose some precision.
        use_spec_dec (bool): Indicates whether to use speculative decoding.
        use_flash_attn (bool): Indicates whether to use flash attention.
    """

    dtype: torch.dtype = None
    use_cuda_kernel: bool = False
    use_spec_dec: bool = False
    use_flash_attn: bool = False
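A short sketch of how the new dataclass is derived from an `InferenceConfig`, based only on the fields visible in this diff. The constructor values are illustrative; any field not shown keeps its default.

```python
import torch
from colossalai.inference.config import InferenceConfig

inference_config = InferenceConfig(
    dtype=torch.float16,
    use_cuda_kernel=True,
    use_spec_dec=False,
    max_batch_size=8,  # illustrative value
)

# use_flash_attn is derived internally via can_use_flash_attn2(dtype),
# i.e. it can only be True for fp16/bf16 with flash-attn installed.
shard_cfg = inference_config.to_model_shard_inference_config()
print(shard_cfg.use_cuda_kernel, shard_cfg.use_flash_attn)
```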
@@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
@@ -72,8 +72,9 @@ class InferenceEngine:

        self.verbose = verbose
        self.logger = get_dist_logger(__name__)
        self.model_shard_infer_config = inference_config.to_model_shard_inference_config()

        self.init_model(model_or_path, model_policy)
        self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

        self.generation_config = inference_config.to_generation_config(self.model_config)
        self.generation_config_dict = self.generation_config.to_dict()
@@ -97,7 +98,8 @@ class InferenceEngine:
            self.capture_model(self.k_cache, self.v_cache)

        # Model and related attrs of speculative decoding will be set by `enable_spec_dec`
        self.use_spec_dec = False
        self.use_spec_dec = self.inference_config.use_spec_dec

        self.drafter_model = None
        self.drafter = None
        self.use_glide = False
@@ -105,13 +107,20 @@ class InferenceEngine:

        self._verify_args()

    def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
    def init_model(
        self,
        model_or_path: Union[nn.Module, str],
        model_policy: Union[Policy, Type[Policy]] = None,
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        """
        Shard model and/or load weight

        Args:
            model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
            model_policy (Policy): the policy to replace the model
            model_policy (Policy): the policy to replace the model.
            model_inference_config: the configuration for modeling initialization when inference.
            model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
        """

        if isinstance(model_or_path, str):
@@ -124,6 +133,7 @@ class InferenceEngine:
                    # the model load process in the future.
                    model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
                else:
                    # TODO(char-1ee): if the model is not supported, use transformers APIs to load and generate
                    raise ValueError(f"Model {arch} is not supported.")

            except Exception as e:
@@ -167,6 +177,7 @@ class InferenceEngine:
            self.model = self._shardformer(
                model,
                model_policy,
                model_shard_infer_config,
                None,
                tp_group=tp_group,
            )
@@ -187,7 +198,7 @@ class InferenceEngine:
        # assert if_has_index_file, "the model path is invalid"
        # cpt_io.load_model(self.model, model_index_file)

        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        free_gpu_memory, _ = torch.cuda.mem_get_info()
        peak_memory = init_gpu_memory - free_gpu_memory
        if self.verbose:
            self.logger.info(
@@ -287,6 +298,7 @@ class InferenceEngine:
        self,
        model: nn.Module,
        model_policy: Policy,
        model_shard_infer_config: ModelShardInferenceConfig = None,
        stage_manager: PipelineStageManager = None,
        tp_group: ProcessGroupMesh = None,
    ) -> nn.Module:
@@ -312,6 +324,7 @@ class InferenceEngine:
            enable_flash_attention=False,
            enable_jit_fused=False,
            enable_sequence_parallelism=False,
            extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
        )
        shardformer = ShardFormer(shard_config=shardconfig)
        shard_model, _ = shardformer.optimize(model, model_policy)
@@ -348,6 +361,7 @@ class InferenceEngine:
            engine.clear_spec_dec()
            ```
        """

        if drafter_model is None and self.drafter is None:
            raise ValueError("Drafter not initialized. Please provide a Drafter Model")
        if n_spec_tokens is not None:
@@ -452,6 +466,7 @@ class InferenceEngine:
                self.k_cache[-1],  # use kv caches of the last layer
                self.v_cache[-1],
                batch.get_sequence_lengths(),
                n_spec_tokens=self.n_spec_tokens,
            )

            drafter_out = self.drafter.speculate(
@@ -517,19 +532,19 @@ class InferenceEngine:
        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
        return_token_ids: bool = False,
        generation_config: Optional[GenerationConfig] = None,
    ) -> List[str]:
    ) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
        """
        Executing the inference step.

        Args:
            prompts (Union[List[str], optional): Input prompts. Defaults to None.
            prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
            request_ids (List[int], optional): The request ID. Defaults to None.
            return_token_ids (bool): Whether to return output token ids. Defaults to False.
            generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
            prompts (Union[List[str], optional): Input prompts. Defaults to None.
            prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
            return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
            generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.

        Returns:
            List[str]: Inference result returned by one generation.
            Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
        """

        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
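A hedged usage sketch of the updated `generate` return type: when `return_token_ids=True`, the call returns a `(texts, token_ids)` tuple. The checkpoint path and the engine constructor arguments below are assumptions abbreviated for illustration; check the engine's actual signature before copying.

```python
from transformers import AutoTokenizer, GenerationConfig

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine  # assumed import path

model_path = "path/to/llama-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path)
inference_config = InferenceConfig(use_cuda_kernel=True)

engine = InferenceEngine(model_path, tokenizer, inference_config, verbose=True)

texts, token_ids = engine.generate(
    prompts=["What is tensor parallelism?"],
    return_token_ids=True,  # new return type: Tuple[List[str], List[List[int]]]
    generation_config=GenerationConfig(max_new_tokens=64),
)
print(texts[0])
```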
@@ -0,0 +1,170 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import context_attention_unpadded, flash_decoding_attention


@dataclass
class AttentionMetaData:
    query_states: torch.Tensor
    key_states: torch.Tensor
    value_states: torch.Tensor
    k_cache: torch.Tensor
    v_cache: torch.Tensor
    block_tables: torch.Tensor
    block_size: int
    kv_seq_len: int = None
    sequence_lengths: torch.Tensor = None
    cu_seqlens: torch.Tensor = None
    sm_scale: int = None
    alibi_slopes: torch.Tensor = None
    output_tensor: torch.Tensor = None
    use_spec_dec: bool = False
    use_alibi_attn: bool = False


class AttentionBackend(ABC):
    @abstractmethod
    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        raise NotImplementedError


class CudaAttentionBackend(AttentionBackend):
    """
    Attention backend when use_cuda_kernel is True. If flash-attn is not found,
    it uses the Triton op `context_attention_unpadded` for prefilling and our CUDA op `flash_decoding_attention` for decoding.
    """

    def __init__(self, use_flash_attn: bool = False):
        super().__init__()
        self.inference_ops = InferenceOpsLoader().load()
        self.use_flash_attn = use_flash_attn

    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        if self.use_flash_attn:
            token_nums = kwargs.get("token_nums", -1)

            from flash_attn import flash_attn_varlen_func

            attn_output = flash_attn_varlen_func(
                attn_metadata.query_states,
                attn_metadata.key_states,
                attn_metadata.value_states,
                cu_seqlens_q=attn_metadata.cu_seqlens,
                cu_seqlens_k=attn_metadata.cu_seqlens,
                max_seqlen_q=attn_metadata.kv_seq_len,
                max_seqlen_k=attn_metadata.kv_seq_len,
                dropout_p=0.0,
                softmax_scale=attn_metadata.sm_scale,
                causal=True,
                alibi_slopes=attn_metadata.alibi_slopes,
            )
            attn_output = attn_output.view(token_nums, -1)
        else:
            attn_output = context_attention_unpadded(
                q=attn_metadata.query_states,
                k=attn_metadata.key_states,
                v=attn_metadata.value_states,
                k_cache=attn_metadata.k_cache,
                v_cache=attn_metadata.v_cache,
                context_lengths=attn_metadata.sequence_lengths,
                block_tables=attn_metadata.block_tables,
                block_size=attn_metadata.block_size,
                output=attn_metadata.output_tensor,
                alibi_slopes=attn_metadata.alibi_slopes,
                max_seq_len=attn_metadata.kv_seq_len,
                sm_scale=attn_metadata.sm_scale,
                use_new_kcache_layout=True,  # use new k-cache layout
            )
        return attn_output

    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
        output_tensor = attn_metadata.output_tensor
        self.inference_ops.flash_decoding_attention(
            output_tensor,
            attn_metadata.query_states,
            attn_metadata.k_cache,
            attn_metadata.v_cache,
            attn_metadata.sequence_lengths,
            attn_metadata.block_tables,
            attn_metadata.block_size,
            attn_metadata.kv_seq_len,
            fd_inter_tensor.mid_output,
            fd_inter_tensor.exp_sums,
            fd_inter_tensor.max_logits,
            attn_metadata.alibi_slopes,
            attn_metadata.sm_scale,
        )
        return output_tensor


class TritonAttentionBackend(AttentionBackend):
    """
    Attention backend when use_cuda_kernel is False. It uses pure Triton ops for prefilling and decoding.
    """

    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        return context_attention_unpadded(
            q=attn_metadata.query_states,
            k=attn_metadata.key_states,
            v=attn_metadata.value_states,
            k_cache=attn_metadata.k_cache,
            v_cache=attn_metadata.v_cache,
            context_lengths=attn_metadata.sequence_lengths,
            block_tables=attn_metadata.block_tables,
            block_size=attn_metadata.block_size,
            output=attn_metadata.output_tensor,
            alibi_slopes=attn_metadata.alibi_slopes,
            max_seq_len=attn_metadata.kv_seq_len,
            sm_scale=attn_metadata.sm_scale,
        )

    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
        return flash_decoding_attention(
            q=attn_metadata.query_states,
            k_cache=attn_metadata.k_cache,
            v_cache=attn_metadata.v_cache,
            kv_seq_len=attn_metadata.sequence_lengths,
            block_tables=attn_metadata.block_tables,
            block_size=attn_metadata.block_size,
            max_seq_len_in_batch=attn_metadata.kv_seq_len,
            output=attn_metadata.output_tensor,
            mid_output=fd_inter_tensor.mid_output,
            mid_output_lse=fd_inter_tensor.mid_output_lse,
            alibi_slopes=attn_metadata.alibi_slopes,
            sm_scale=attn_metadata.sm_scale,
            kv_group_num=kwargs.get("num_key_value_groups", 1),
            q_len=kwargs.get("q_len", 1),
        )


def get_attention_backend(
    model_shard_infer_config: ModelShardInferenceConfig,
) -> AttentionBackend:
    """
    Get the attention backend based on the inference configuration. The modeling will use the CUDA-kernel-based backend
    for attention module calculation only when:
        1. using CUDA kernel (use_cuda_kernel=True)
        2. flash attention is usable (flash-attn installed and dtype is fp16 or bf16)
        3. not using speculative decoding (currently the CUDA kernels do not support speculative decoding)
    Otherwise, the Triton attention backend is used. If flash-attn is not installed while `use_cuda_kernel` is True,
    the CUDA backend falls back to the Triton prefill kernel with the new k-cache layout.
    """
    # Currently only triton kernels support speculative decoding
    if model_shard_infer_config.use_spec_dec:
        return TritonAttentionBackend()

    if model_shard_infer_config.use_cuda_kernel:
        return CudaAttentionBackend(model_shard_infer_config.use_flash_attn)

    return TritonAttentionBackend()
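A small sketch of the selection logic in `get_attention_backend`, using only names defined in this new file. Note that constructing `CudaAttentionBackend` loads the CUDA inference ops, so the second case assumes the extension can be built/loaded on the host.

```python
import torch
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.modeling.backends.attention_backend import (
    CudaAttentionBackend,
    TritonAttentionBackend,
    get_attention_backend,
)

# Speculative decoding always falls back to Triton, regardless of use_cuda_kernel.
cfg = ModelShardInferenceConfig(dtype=torch.float16, use_cuda_kernel=True, use_spec_dec=True, use_flash_attn=True)
assert isinstance(get_attention_backend(cfg), TritonAttentionBackend)

# CUDA kernels enabled, no speculative decoding -> CUDA backend
# (flash-attn is only used inside it when use_flash_attn is True).
cfg = ModelShardInferenceConfig(dtype=torch.float16, use_cuda_kernel=True, use_spec_dec=False, use_flash_attn=False)
assert isinstance(get_attention_backend(cfg), CudaAttentionBackend)

# Default: pure Triton backend.
cfg = ModelShardInferenceConfig(dtype=torch.float16)
assert isinstance(get_attention_backend(cfg), TritonAttentionBackend)
```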
@@ -0,0 +1,146 @@
from abc import ABC, abstractmethod

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_k_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding


class PreAttentionBackend(ABC):
    @abstractmethod
    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        raise NotImplementedError


class CudaPreAttentionBackend(PreAttentionBackend):
    """
    CudaPreAttentionBackend handles KV cache initialization and positional encoding for CudaAttentionBackend.
    """

    def __init__(self, use_flash_attn: bool):
        super().__init__()
        self.inference_ops = InferenceOpsLoader().load()
        self.use_flash_attn = use_flash_attn

    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        if self.use_flash_attn:
            if not attn_metadata.use_alibi_attn:
                self.inference_ops.rotary_embedding(
                    attn_metadata.query_states,
                    attn_metadata.key_states,
                    kwargs.get("cos", None),
                    kwargs.get("sin", None),
                    kwargs.get("high_precision", False),
                )
            self.inference_ops.context_kv_cache_memcpy(
                attn_metadata.key_states,
                attn_metadata.value_states,
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.sequence_lengths,
                attn_metadata.cu_seqlens,
                attn_metadata.block_tables,
                attn_metadata.kv_seq_len,
            )
        elif not attn_metadata.use_alibi_attn:
            rotary_embedding(
                attn_metadata.query_states,
                attn_metadata.key_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
            )

    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        if not attn_metadata.use_alibi_attn:
            self.inference_ops.rotary_embedding_and_cache_copy(
                attn_metadata.query_states,
                attn_metadata.key_states,
                attn_metadata.value_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.sequence_lengths,
                attn_metadata.block_tables,
                kwargs.get("high_precision", None),
            )
        else:
            self.inference_ops.decode_kv_cache_memcpy(
                attn_metadata.key_states,
                attn_metadata.value_states,
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.sequence_lengths,
                attn_metadata.block_tables,
            )


class TritonPreAttentionBackend(PreAttentionBackend):
    """
    TritonPreAttentionBackend handles KV cache initialization and positional encoding for TritonAttentionBackend.
    """

    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        if not attn_metadata.use_alibi_attn:
            rotary_embedding(
                attn_metadata.query_states,
                attn_metadata.key_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
            )

    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        if not attn_metadata.use_spec_dec and not attn_metadata.use_alibi_attn:
            decoding_fused_rotary_embedding(
                attn_metadata.query_states,
                attn_metadata.key_states,
                attn_metadata.value_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.block_tables,
                attn_metadata.sequence_lengths,
            )
        else:  # else if using speculative decoding
            if not attn_metadata.use_alibi_attn:
                rotary_embedding(
                    attn_metadata.query_states,
                    attn_metadata.key_states,
                    kwargs.get("cos", None),
                    kwargs.get("sin", None),
                )
            copy_k_to_blocked_cache(
                attn_metadata.key_states,
                attn_metadata.k_cache,
                kv_lengths=attn_metadata.sequence_lengths,
                block_tables=attn_metadata.block_tables,
                n=kwargs.get("q_len", 1),
            )
            copy_k_to_blocked_cache(
                attn_metadata.value_states,
                attn_metadata.v_cache,
                kv_lengths=attn_metadata.sequence_lengths,
                block_tables=attn_metadata.block_tables,
                n=kwargs.get("q_len", 1),
            )


def get_pre_attention_backend(
    model_shard_infer_config: ModelShardInferenceConfig,
) -> PreAttentionBackend:
    """
    Get the backend for pre-attention computations, i.e. positional encoding such as
    RoPE and KV cache initialization. It adopts the same selection logic as attention_backend/get_attention_backend.
    """
    if model_shard_infer_config.use_spec_dec:
        return TritonPreAttentionBackend()

    if model_shard_infer_config.use_cuda_kernel:
        return CudaPreAttentionBackend(model_shard_infer_config.use_flash_attn)

    return TritonPreAttentionBackend()
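Because the two selector functions share the same rules, a modeling layer always ends up with a matching pair: the pre-attention backend applies RoPE and writes the KV cache, and the attention backend then computes attention over it. A quick, hedged parity check (the CUDA case again assumes the inference ops extension can be loaded):

```python
import torch
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.modeling.backends.attention_backend import get_attention_backend
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend

cfg = ModelShardInferenceConfig(dtype=torch.float16, use_cuda_kernel=True, use_spec_dec=False, use_flash_attn=True)

pre_attn = get_pre_attention_backend(cfg)
attn = get_attention_backend(cfg)
print(type(pre_attn).__name__, type(attn).__name__)  # CudaPreAttentionBackend CudaAttentionBackend
```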
@@ -15,25 +15,10 @@ class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
        module.in_features = module.weight.size(1)
        module.out_features = module.weight.size(0)
        module.bias = None
        module.weight.data = nn.functional.normalize(module.weight)

        return Linear1D_Col.from_native_module(
            module,
            process_group,
            *args,
            **kwargs,
        )


class BaichuanWpackLinear1D_Col(Linear1D_Col):
    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        in_features = module.in_features * 3
        out_features = module.out_features // 3
        module.weight.data = module.weight.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features)
        module.bias = None
        module.weight.data = nn.functional.normalize(
            module.weight
        )  # TODO(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
        # So we should rewrite our own load_from_state_dict of `BaichuanLMHeadLinear1D_Col` to fix this potential issue.

        return Linear1D_Col.from_native_module(
            module,
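A tiny, hedged demo of the weight re-layout used for the packed projection above (`view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features)` with `in_features = 3 * in`). It assumes, as the `out_features // 3` in the diff implies, that W_pack stacks the q/k/v blocks along the output dimension; after the re-layout, row i is the concatenation of row i of each block.

```python
import torch

out_f, in_f = 2, 2
wq = torch.full((out_f, in_f), 1.0)
wk = torch.full((out_f, in_f), 2.0)
wv = torch.full((out_f, in_f), 3.0)

# Assumed W_pack layout: q/k/v stacked along the output dim -> shape (3*out_f, in_f)
w_pack = torch.cat([wq, wk, wv], dim=0)

# Same chain as in the diff: (3*out_f, in_f) -> (out_f, 3*in_f)
w_merged = w_pack.view(3, out_f, -1).transpose(0, 1).reshape(out_f, 3 * in_f)
print(w_merged)
# tensor([[1., 1., 2., 2., 3., 3.],
#         [1., 1., 2., 2., 3., 3.]])
```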
@@ -6,11 +6,7 @@ from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
@@ -137,6 +133,7 @@ def glide_llama_model_forward(
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
@@ -147,57 +144,43 @@ def glide_llama_model_forward(
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape[:2]
    elif inputs_embeds is not None:
        batch_size, seq_length = inputs_embeds.shape[:2]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    past_key_values_length = 0
    if use_cache:
        use_legacy_cache = not isinstance(past_key_values, Cache)
        if use_legacy_cache:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        past_key_values_length = past_key_values.get_usable_length(seq_length)

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
        )
        position_ids = position_ids.unsqueeze(0)

    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
        use_cache = False

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if self._use_flash_attention_2:
        # 2d mask is passed through the layers
        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    elif self._use_sdpa and not output_attentions:
        # output_attentions=True can not be supported when using SDPA, and we fall back on
        # the manual implementation that requires a 4D causal mask in all cases.
        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )
    else:
        # 4d mask is passed through the layers
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    past_seen_tokens = 0
    if use_cache:  # kept for BC (cache positions)
        if not isinstance(past_key_values, StaticCache):
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_seen_tokens = past_key_values.get_seq_length()

    if cache_position is None:
        if isinstance(past_key_values, StaticCache):
            raise ValueError("cache_position is a required argument when using StaticCache.")
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
@@ -212,6 +195,7 @@ def glide_llama_model_forward(
            past_key_value=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )

        hidden_states = layer_outputs[0]
@@ -230,7 +214,9 @@ def glide_llama_model_forward(

    next_cache = None
    if use_cache:
        next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
        next_cache = (
            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
        )
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
@@ -333,7 +319,8 @@ class LlamaCrossAttention(nn.Module):
        query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)

        # for RoPE
        cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len + 32)
        position_ids = position_ids + glide_input.n_spec_tokens
        cos, sin = self.rotary_emb(query_states, position_ids)
        query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
        query_states = query_states.transpose(1, 2)
        query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)
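The rewritten forward derives positions from `cache_position` (seeded by the number of tokens already in the KV cache) instead of a manually tracked `past_key_values_length`. A tiny sketch of the arithmetic, with illustrative numbers:

```python
import torch

past_seen_tokens, seq_len = 6, 3  # 6 tokens already cached, 3 new tokens in this step
inputs_embeds = torch.zeros(1, seq_len, 16)

cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])
position_ids = cache_position.unsqueeze(0)

print(cache_position.tolist())  # [6, 7, 8]
print(position_ids.shape)       # torch.Size([1, 3])
```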
@@ -1,68 +1,27 @@
# This code is adapted from huggingface baichuan model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
import itertools
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.accelerator import get_accelerator
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
    context_attention_unpadded,
    copy_k_to_blocked_cache,
    decoding_fused_rotary_embedding,
    flash_decoding_attention,
    rms_layernorm,
    rotary_embedding,
)
from colossalai.kernel.triton import rms_layernorm
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor

logger = get_dist_logger(__name__)

try:
    from flash_attn import flash_attn_varlen_func

    use_flash_attn2 = True
except ImportError:
    use_flash_attn2 = False
    logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")

logger = get_dist_logger(__name__)

try:
    from flash_attn import flash_attn_varlen_func

    use_flash_attn2 = True
except ImportError:
    use_flash_attn2 = False
    logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
from colossalai.tensor.d_tensor import is_distributed_tensor

inference_ops = InferenceOpsLoader().load()

logger = get_dist_logger(__name__)


# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
    slopes = torch.pow(base, powers)
    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes
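The local ALiBi-slope helper removed above is now imported from `colossalai.inference.utils` (see the new import at the top of the file). Assuming the relocated helper keeps the same formula as the removed copy, a power-of-two head count yields a simple geometric sequence:

```python
import torch
from colossalai.inference.utils import get_alibi_slopes  # relocated helper, per the new import above

# For num_heads = 8: base = 2 ** -(2 ** -(log2(8) - 3)) = 0.5, so slopes are 0.5**1 .. 0.5**8.
slopes = get_alibi_slopes(8, device=torch.device("cpu"))
print(slopes)
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
```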


def baichuan_rmsnorm_forward(
    self,
    hidden_states: torch.Tensor,
@@ -96,23 +55,19 @@ class NopadBaichuanAttention(ParallelModule):
    def __init__(
        self,
        config,
        attn_qproj_w: torch.Tensor = None,
        attn_kproj_w: torch.Tensor = None,
        attn_vproj_w: torch.Tensor = None,
        W_pack: ParallelModule = None,
        attn_oproj: ParallelModule = None,
        num_heads: int = None,
        hidden_size: int = None,
        model_shard_infer_config: ModelShardInferenceConfig = None,
        process_group: ProcessGroup = None,
        helper_layout: Layout = None,
    ):
        """This layer will replace the BaichuanAttention.

        Args:
            config (BaichuanConfig): Holding the Baichuan model config.
            attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
            attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
            attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
            attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj weight. Defaults to None.
            W_pack (ParallelModule, optional): The packed weight. Defaults to None.
            attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj. Defaults to None.
        """
        ParallelModule.__init__(self)
        self.o_proj = attn_oproj
@@ -122,10 +77,10 @@ class NopadBaichuanAttention(ParallelModule):
        self.hidden_size = hidden_size
        self.head_dim = self.hidden_size // self.num_heads
        self.process_group = process_group
        qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
        self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))

        self.helper_layout = helper_layout
        self.W_pack = W_pack
        self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
        self.attention_backend = get_attention_backend(model_shard_infer_config)
        self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)

        self.alibi_slopes = None
        self.use_alibi_attn = False
@@ -133,9 +88,9 @@ class NopadBaichuanAttention(ParallelModule):
        if config.hidden_size == 5120:
            slopes_start = self.process_group.rank() * num_heads
            self.use_alibi_attn = True
            self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
                slopes_start : slopes_start + num_heads
            ].contiguous()
            self.alibi_slopes = get_alibi_slopes(
                config.num_attention_heads, device=get_accelerator().get_current_device()
            )[slopes_start : slopes_start + num_heads].contiguous()
            self.alibi_slopes = nn.Parameter(self.alibi_slopes)

    @staticmethod
@@ -149,76 +104,22 @@ class NopadBaichuanAttention(ParallelModule):
        """

        config = module.config
        q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((module.hidden_size, 3, -1)).transpose(0, 1)

        attn_qproj_w = q_proj_w
        attn_kproj_w = k_proj_w
        attn_vproj_w = v_proj_w
        W_pack = module.W_pack
        attn_oproj = module.o_proj

        helper_layout = (
            module.W_pack.weight.dist_layout
        )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)

        attn_layer = NopadBaichuanAttention(
            config=config,
            attn_qproj_w=attn_qproj_w,
            attn_kproj_w=attn_kproj_w,
            attn_vproj_w=attn_vproj_w,
            W_pack=W_pack,
            attn_oproj=attn_oproj,
            model_shard_infer_config=model_shard_infer_config,
            num_heads=module.num_heads,
            hidden_size=module.hidden_size,
            process_group=process_group,
            helper_layout=helper_layout,
        )

        return attn_layer

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
        local_state = {k: v for k, v in local_name_params if v is not None}

        key = "qkv_weight"
        qkv_w = state_dict[prefix + "W_pack.weight"]

        in_features = qkv_w.size(1)
        out_features = qkv_w.size(0) // 3

        qkv_w.data = qkv_w.view((3, out_features, -1)).transpose(0, 1).reshape(out_features, in_features * 3)

        device_mesh = self.helper_layout.device_mesh
        sharding_spec = self.helper_layout.sharding_spec
        qkv_w = distribute_tensor(qkv_w, device_mesh, sharding_spec)

        qkv_w = qkv_w.transpose(0, 1).reshape(3, in_features, -1)
        input_param = nn.Parameter(
            qkv_w
        )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)

        param = local_state[key]

        try:
            with torch.no_grad():
                param.copy_(input_param)
        except Exception as ex:
            error_msgs.append(
                'While copying the parameter named "{}", '
                "whose dimensions in the model are {} and "
                "whose dimensions in the checkpoint are {}, "
                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
            )

        strict = False  # to avoid unexpected_keys
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -234,7 +135,6 @@ class NopadBaichuanAttention(ParallelModule):
        kv_seq_len: int = 0,
        output_tensor: torch.Tensor = None,
        sm_scale: int = None,
        use_cuda_kernel: bool = True,
        cu_seqlens: torch.Tensor = None,
        high_precision: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -253,144 +153,66 @@ class NopadBaichuanAttention(ParallelModule):
            kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
            output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
            sm_scale (int, optional): Used for flash attention. Defaults to None.
            use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
            cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
            high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
        """

        token_nums = hidden_states.size(0)
        # fused qkv
        hidden_states = hidden_states.expand(3, -1, -1)
        query_states, key_states, value_states = (
            torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
        )

        proj = self.W_pack(hidden_states)
        proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
        query_states = proj[0].view(token_nums, self.num_heads, self.head_dim)
        key_states = proj[1].view(token_nums, self.num_heads, self.head_dim)
        value_states = proj[2].view(token_nums, self.num_heads, self.head_dim)

        block_size = k_cache.size(-2)

        if is_prompts:
            if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
                # flash attn 2 currently only supports FP16/BF16.
                if not self.use_alibi_attn:
                    inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
                inference_ops.context_kv_cache_memcpy(
                    key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
                )
                attn_output = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens,
                    cu_seqlens_k=cu_seqlens,
                    max_seqlen_q=kv_seq_len,
                    max_seqlen_k=kv_seq_len,
                    dropout_p=0.0,
                    softmax_scale=sm_scale,
                    causal=True,
                    alibi_slopes=self.alibi_slopes,
                )
                attn_output = attn_output.view(token_nums, -1)
            else:
                if not self.use_alibi_attn:
                    rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
                attn_output = context_attention_unpadded(
                    q=query_states,
                    k=key_states,
                    v=value_states,
                    k_cache=k_cache,
                    v_cache=v_cache,
                    context_lengths=sequence_lengths,
                    block_tables=block_tables,
                    block_size=block_size,
                    output=output_tensor,
                    alibi_slopes=self.alibi_slopes,
                    max_seq_len=kv_seq_len,
                    sm_scale=sm_scale,
                    use_new_kcache_layout=use_cuda_kernel,
                )
        else:
        attn_metadata = AttentionMetaData(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            k_cache=k_cache,
            v_cache=v_cache,
            block_tables=block_tables,
            block_size=block_size,
            kv_seq_len=kv_seq_len,
            sequence_lengths=sequence_lengths,
            sm_scale=sm_scale,
            alibi_slopes=self.alibi_slopes,
            cu_seqlens=cu_seqlens,
            output_tensor=output_tensor,
            use_spec_dec=is_verifier,
            use_alibi_attn=self.use_alibi_attn,
        )

        if is_prompts:  # prefilling stage
            self.pre_attention_backend.prefill(
                attn_metadata,
                cos=cos_sin[0],
                sin=cos_sin[1],
                high_precision=high_precision,
            )
            attn_output = self.attention_backend.prefill(
                attn_metadata,
                token_nums=token_nums,
            )
        else:  # decoding stage
            q_len = tokens_to_verify + 1 if is_verifier else 1

            if use_cuda_kernel:
                if not self.use_alibi_attn:
                    inference_ops.rotary_embedding_and_cache_copy(
                        query_states,
                        key_states,
                        value_states,
                        cos_sin[0],
                        cos_sin[1],
                        k_cache,
                        v_cache,
                        sequence_lengths,
                        block_tables,
                        high_precision,
                    )
                else:
                    inference_ops.decode_kv_cache_memcpy(
                        key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
                    )
                inference_ops.flash_decoding_attention(
                    output_tensor,
                    query_states,
                    k_cache,
                    v_cache,
                    sequence_lengths,
                    block_tables,
                    block_size,
                    kv_seq_len,
                    fd_inter_tensor.mid_output,
                    fd_inter_tensor.exp_sums,
                    fd_inter_tensor.max_logits,
                    self.alibi_slopes,
                    sm_scale,
                )
                attn_output = output_tensor
            else:
                if not is_verifier and not self.use_alibi_attn:
                    decoding_fused_rotary_embedding(
                        query_states,
                        key_states,
                        value_states,
                        cos_sin[0],
                        cos_sin[1],
                        k_cache,
                        v_cache,
                        block_tables,
                        sequence_lengths,
                    )
                else:
                    if not self.use_alibi_attn:
                        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
                    copy_k_to_blocked_cache(
                        key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                    )
                    copy_k_to_blocked_cache(
                        value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                    )

                attn_output = flash_decoding_attention(
                    q=query_states,
                    k_cache=k_cache,
                    v_cache=v_cache,
                    kv_seq_len=sequence_lengths,
                    block_tables=block_tables,
                    block_size=block_size,
                    max_seq_len_in_batch=kv_seq_len,
                    output=output_tensor,
                    mid_output=fd_inter_tensor.mid_output,
                    mid_output_lse=fd_inter_tensor.mid_output_lse,
                    alibi_slopes=self.alibi_slopes,
                    sm_scale=sm_scale,
                    q_len=q_len,
                )
            self.pre_attention_backend.decode(
                attn_metadata,
                q_len=q_len,
            )
            attn_output = self.attention_backend.decode(
                attn_metadata,
                fd_inter_tensor=fd_inter_tensor,
                q_len=q_len,
            )

        attn_output = attn_output.view(-1, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output

    def extra_repr(self) -> str:
        return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"


# NOTE This will cause difference as out length increases.
class NopadBaichuanMLP(NopadLlamaMLP):
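A short, hedged shape walk-through of the new fused-QKV path in the Baichuan forward, which replaces the old `torch.bmm` over a stacked `qkv_weight` with a single `W_pack` projection followed by a reshape. The sizes below are toy values and `hidden_size` here stands for the per-rank (possibly tensor-parallel-sharded) size.

```python
import torch

token_nums, hidden_size, num_heads = 4, 8, 2
head_dim = hidden_size // num_heads

# Stand-in for W_pack(hidden_states): one fused projection of shape (token_nums, 3 * hidden_size)
proj = torch.randn(token_nums, 3 * hidden_size)

# Same chain as in the new forward: split the last dim into (3, hidden_size)
# and move the q/k/v axis to the front -> (3, token_nums, hidden_size)
proj = proj.unflatten(-1, (3, hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
print(proj.shape)  # torch.Size([3, 4, 8])

query_states = proj[0].view(token_nums, num_heads, head_dim)
key_states = proj[1].view(token_nums, num_heads, head_dim)
value_states = proj[2].view(token_nums, num_heads, head_dim)
print(query_states.shape)  # torch.Size([4, 2, 4])
```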
@@ -16,18 +16,13 @@ from transformers.models.llama.modeling_llama import (
    LlamaRMSNorm,
)

from colossalai.inference.config import InputMetaData
from colossalai.inference.config import InputMetaData, ModelShardInferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
from colossalai.inference.utils import can_use_flash_attn2
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
    context_attention_unpadded,
    copy_k_to_blocked_cache,
    decoding_fused_rotary_embedding,
    flash_decoding_attention,
    get_xine_cache,
    rms_layernorm,
    rotary_embedding,
)
from colossalai.kernel.triton import get_xine_cache, rms_layernorm
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import distribute_tensor, is_distributed_tensor
@@ -36,14 +31,6 @@ inference_ops = InferenceOpsLoader().load()

logger = get_dist_logger(__name__)

try:
    from flash_attn import flash_attn_varlen_func

    use_flash_attn2 = True
except ImportError:
    use_flash_attn2 = False
    logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")


def llama_causal_lm_forward(
    self: LlamaForCausalLM,
@@ -126,8 +113,8 @@ def llama_model_forward(
        cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])

    elif use_cuda_kernel:
        if inputmetadata.dtype != torch.float32 and use_flash_attn2:
            cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
        if can_use_flash_attn2(inputmetadata.dtype):
            cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))

        hidden_dim = self._cos_cached.size(-1)
        total_length = hidden_states.size(0)
@ -238,7 +225,6 @@ def llama_decoder_layer_forward(
|
|||
kv_seq_len=kv_seq_len,
|
||||
output_tensor=output_tensor,
|
||||
sm_scale=sm_scale,
|
||||
use_cuda_kernel=use_cuda_kernel,
|
||||
cu_seqlens=cu_seqlens,
|
||||
high_precision=high_precision,
|
||||
)
|
||||
|
@ -279,7 +265,7 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
|
|||
mlp_dproj: ParallelModule = None,
|
||||
process_group: ProcessGroup = None,
|
||||
):
|
||||
"""A Unified Layer for
|
||||
"""Replacement of LlamaMLP layer.
|
||||
|
||||
Args:
|
||||
config (LlamaConfig): Holding the Llama model config.
|
||||
|
@ -402,6 +388,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
|
|||
attn_vproj_w: torch.Tensor = None,
|
||||
attn_oproj: ParallelModule = None,
|
||||
process_group: ProcessGroup = None,
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
num_heads: int = None,
|
||||
hidden_size: int = None,
|
||||
num_key_value_heads: int = None,
|
||||
|
@@ -433,6 +420,9 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
        self.rope_theta = config.rope_theta
        self.is_causal = True

        self.attention_backend = get_attention_backend(model_shard_infer_config)
        self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)

        if self.num_heads == self.num_key_value_heads:
            qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
            self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))

@@ -462,6 +452,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
        attn_vproj_w = module.v_proj.weight
        assert is_distributed_tensor(attn_qproj_w), "attn_qproj_w must be dist tensor"
        attn_oproj = module.o_proj
        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)

        attn_layer = NopadLlamaAttention(
            config=config,

@@ -471,6 +462,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
            attn_vproj_w=attn_vproj_w,
            attn_oproj=attn_oproj,
            process_group=process_group,
            model_shard_infer_config=model_shard_infer_config,
            num_heads=module.num_heads,
            hidden_size=module.hidden_size,
            num_key_value_heads=module.num_key_value_heads,
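The constructor above stacks the transposed q/k/v projection weights into a single `qkv_weight` of shape [3, hidden_size, proj_size], so the three projections can be computed in one batched matmul. An illustrative sketch with hypothetical small sizes (not the repo's exact kernel path):

# Illustrative sketch (hypothetical sizes): one batched matmul replaces three
# separate q/k/v projections once the transposed weights are stacked.
import torch

hidden_size, proj_size, n_tokens = 8, 8, 4
wq, wk, wv = (torch.randn(proj_size, hidden_size) for _ in range(3))
qkv_weight = torch.stack([wq.T, wk.T, wv.T], dim=0)   # [3, hidden_size, proj_size]

x = torch.randn(n_tokens, hidden_size)
qkv = torch.bmm(x.expand(3, -1, -1), qkv_weight)      # [3, n_tokens, proj_size]
q, k, v = qkv.unbind(0)

# Matches the three separate projections.
assert torch.allclose(q, x @ wq.T, atol=1e-5)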
@@ -533,111 +525,50 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):

        block_size = k_cache.size(-2)

        if is_prompts:
            if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
                # flash attn 2 currently only supports FP16/BF16.
                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
                inference_ops.context_kv_cache_memcpy(
                    key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
                )
        attn_metadata = AttentionMetaData(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            k_cache=k_cache,
            v_cache=v_cache,
            block_tables=block_tables,
            block_size=block_size,
            kv_seq_len=kv_seq_len,
            sequence_lengths=sequence_lengths,
            sm_scale=sm_scale,
            alibi_slopes=None,
            cu_seqlens=cu_seqlens,
            output_tensor=output_tensor,
            use_spec_dec=is_verifier,
            use_alibi_attn=False,
        )

                attn_output = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens,
                    cu_seqlens_k=cu_seqlens,
                    max_seqlen_q=kv_seq_len,
                    max_seqlen_k=kv_seq_len,
                    dropout_p=0.0,
                    softmax_scale=sm_scale,
                    causal=True,
                )
                attn_output = attn_output.view(token_nums, -1)
            else:
                rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
                attn_output = context_attention_unpadded(
                    q=query_states,
                    k=key_states,
                    v=value_states,
                    k_cache=k_cache,
                    v_cache=v_cache,
                    context_lengths=sequence_lengths,
                    block_tables=block_tables,
                    block_size=block_size,
                    output=output_tensor,
                    max_seq_len=kv_seq_len,
                    sm_scale=sm_scale,
                    use_new_kcache_layout=use_cuda_kernel,
                )
        else:
        if is_prompts:  # prefilling stage
            self.pre_attention_backend.prefill(
                attn_metadata,
                cos=cos_sin[0],
                sin=cos_sin[1],
                high_precision=high_precision,
            )
            attn_output = self.attention_backend.prefill(
                attn_metadata,
                token_nums=token_nums,
            )
        else:  # decoding stage
            q_len = tokens_to_verify + 1 if is_verifier else 1

            if use_cuda_kernel:
                inference_ops.rotary_embedding_and_cache_copy(
                    query_states,
                    key_states,
                    value_states,
                    cos_sin[0],
                    cos_sin[1],
                    k_cache,
                    v_cache,
                    sequence_lengths,
                    block_tables,
                    high_precision,
                )
                inference_ops.flash_decoding_attention(
                    output_tensor,
                    query_states,
                    k_cache,
                    v_cache,
                    sequence_lengths,
                    block_tables,
                    block_size,
                    kv_seq_len,
                    fd_inter_tensor.mid_output,
                    fd_inter_tensor.exp_sums,
                    fd_inter_tensor.max_logits,
                    None,
                    sm_scale,
                )
                attn_output = output_tensor
            else:
                if is_verifier:
                    rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
                    copy_k_to_blocked_cache(
                        key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                    )
                    copy_k_to_blocked_cache(
                        value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                    )
                else:
                    decoding_fused_rotary_embedding(
                        query_states,
                        key_states,
                        value_states,
                        cos_sin[0],
                        cos_sin[1],
                        k_cache,
                        v_cache,
                        block_tables,
                        sequence_lengths,
                    )
                attn_output = flash_decoding_attention(
                    q=query_states,
                    k_cache=k_cache,
                    v_cache=v_cache,
                    kv_seq_len=sequence_lengths,
                    block_tables=block_tables,
                    block_size=block_size,
                    max_seq_len_in_batch=kv_seq_len,
                    output=output_tensor,
                    mid_output=fd_inter_tensor.mid_output,
                    mid_output_lse=fd_inter_tensor.mid_output_lse,
                    sm_scale=sm_scale,
                    kv_group_num=self.num_key_value_groups,
                    q_len=q_len,
                )
            self.pre_attention_backend.decode(
                attn_metadata,
                cos=cos_sin[0],
                sin=cos_sin[1],
                q_len=q_len,
            )
            attn_output = self.attention_backend.decode(
                attn_metadata,
                fd_inter_tensor=fd_inter_tensor,
                num_key_value_groups=self.num_key_value_groups,
                q_len=q_len,
            )

        attn_output = attn_output.view(-1, self.hidden_size)
        attn_output = self.o_proj(attn_output)
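The hunk above replaces the inline kernel selection (CUDA ops vs. Triton kernels, prefill vs. decode, flash-attn vs. fallback) with an `AttentionMetaData` object plus backend objects obtained once from `get_attention_backend` / `get_pre_attention_backend`. A hedged sketch of that dispatch pattern; the class and function names below are illustrative and not the actual colossalai.inference API:

# Hypothetical sketch of the backend-dispatch pattern: pack per-call tensors into
# one metadata object, pick a backend object up front, and keep prefill/decode
# call sites branch-free.
from dataclasses import dataclass

@dataclass
class AttnMeta:
    use_cuda_kernel: bool
    use_alibi: bool

class TritonBackend:
    def prefill(self, meta, **kw): return "triton-prefill"
    def decode(self, meta, **kw): return "triton-decode"

class CudaBackend(TritonBackend):
    def prefill(self, meta, **kw): return "cuda-prefill"
    def decode(self, meta, **kw): return "cuda-decode"

def get_backend(meta: AttnMeta):
    return CudaBackend() if meta.use_cuda_kernel else TritonBackend()

meta = AttnMeta(use_cuda_kernel=True, use_alibi=False)
backend = get_backend(meta)   # chosen once per shard config
print(backend.decode(meta))   # "cuda-decode"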
@@ -1,8 +1,5 @@
from colossalai.inference.config import RPC_PARAM
from colossalai.inference.modeling.layers.baichuan_tp_linear import (
    BaichuanLMHeadLinear1D_Col,
    BaichuanWpackLinear1D_Col,
)
from colossalai.inference.modeling.layers.baichuan_tp_linear import BaichuanLMHeadLinear1D_Col
from colossalai.inference.modeling.models.nopadding_baichuan import (
    NopadBaichuanAttention,
    NopadBaichuanMLP,

@@ -14,7 +11,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
    llama_model_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer import FusedLinear1D_Col, Linear1D_Col, Linear1D_Row
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy


@@ -60,8 +57,7 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
                    target_module=NopadBaichuanMLP,
                ),
                SubModuleReplacementDescription(
                    suffix="self_attn.W_pack",
                    target_module=BaichuanWpackLinear1D_Col,
                    suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
                ),
                SubModuleReplacementDescription(
                    suffix="self_attn.o_proj",

@@ -70,6 +66,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
                SubModuleReplacementDescription(
                    suffix="self_attn",
                    target_module=NopadBaichuanAttention,
                    kwargs={
                        "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
                    },
                ),
            ],
        )
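The policy now shards Baichuan's packed `W_pack` projection with the generic `FusedLinear1D_Col` and `n_fused=3`, i.e. the weight is treated as three stacked Q/K/V segments that must each be split consistently across tensor-parallel ranks. A hedged sketch of what that splitting convention means, with hypothetical sizes:

# Hedged sketch (hypothetical sizes): with n_fused=3, a packed weight of shape
# [3 * hidden, hidden] is split so each rank keeps a contiguous slice of every
# fused segment (Q, K, V) rather than a slice of the raw concatenation.
import torch

hidden, tp_size = 8, 2
w_pack = torch.arange(3 * hidden * hidden, dtype=torch.float32).reshape(3 * hidden, hidden)

segments = w_pack.chunk(3, dim=0)                      # Q, K, V blocks
per_rank = [torch.cat([seg.chunk(tp_size, dim=0)[r] for seg in segments], dim=0)
            for r in range(tp_size)]

print([w.shape for w in per_rank])  # each rank: [3 * hidden // tp_size, hidden]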
@@ -72,6 +72,9 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
                SubModuleReplacementDescription(
                    suffix="self_attn",
                    target_module=NopadLlamaAttention,
                    kwargs={
                        "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
                    },
                ),
            ],
        )

@@ -46,6 +46,7 @@ class GlideInput:
    large_k_cache: torch.Tensor = None
    large_v_cache: torch.Tensor = None
    sequence_lengths: torch.Tensor = None
    n_spec_tokens: int = 5

    @property
    def glimpse_ready(self):
@@ -1,6 +1,7 @@
"""
Utils for model inference
"""
import math
import os
import re
from pathlib import Path

@@ -9,8 +10,11 @@ from typing import Optional, Tuple
import torch
from torch import nn

from colossalai.logging import get_dist_logger
from colossalai.testing import free_port

logger = get_dist_logger(__name__)


def init_to_get_rotary(self, base=10000, use_elem=False):
    """

@@ -113,3 +117,44 @@ def find_available_ports(num: int):
        print(f"An OS error occurred: {e}")
        raise RuntimeError("Error finding available ports")
    return free_ports


def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
    """
    Alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57

    Args:
        num_heads (int): The number of attention heads.
        device (torch.device): The device to use.

    Returns:
        torch.Tensor: The Alibi slopes.
    """
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
    slopes = torch.pow(base, powers)
    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes
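`get_alibi_slopes` reproduces the Bloom-style geometric slope schedule. A short usage sketch (CPU device assumed): for a power-of-two head count the slopes are simply 2^(-8/num_heads * i).

# Usage sketch for get_alibi_slopes (CPU assumed).
import torch

slopes = get_alibi_slopes(num_heads=8, device=torch.device("cpu"))
print(slopes)
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])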


def can_use_flash_attn2(dtype: torch.dtype) -> bool:
    """
    Check flash attention2 availability.
    """
    if dtype not in (torch.float16, torch.bfloat16):
        return False

    try:
        from flash_attn import flash_attn_varlen_func  # noqa

        return True
    except ImportError:
        logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
        return False
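This helper replaces the module-level `use_flash_attn2` try/except probe removed from nopadding_llama.py above, and additionally enforces the FP16/BF16 restriction. A hedged usage sketch:

# Hedged usage sketch: gate the flash-attn-2 path on both dtype and install state.
import torch

def pick_prefill_path(dtype: torch.dtype) -> str:
    if can_use_flash_attn2(dtype):
        return "flash_attn_varlen_func"          # FP16/BF16 and flash-attn installed
    return "triton context_attention_unpadded"   # fallback path

print(pick_prefill_path(torch.float32))   # always the triton fallback
print(pick_prefill_path(torch.float16))   # flash-attn 2 only if it is installed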
@@ -45,7 +45,10 @@ def launch(
    backend = cur_accelerator.communication_backend

    # init default process group
    init_method = f"tcp://[{host}]:{port}"
    if ":" in host:  # IPv6
        init_method = f"tcp://[{host}]:{port}"
    else:  # IPv4
        init_method = f"tcp://{host}:{port}"
    dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)

    # set cuda device
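The old code always bracketed the host, which yields an invalid URL for plain IPv4 addresses; the new branch only brackets IPv6 literals. The two resulting `init_method` strings, with illustrative host values:

# Illustrative values only:
host_v4, host_v6, port = "10.0.0.1", "fe80::1", 29500
print(f"tcp://{host_v4}:{port}")    # tcp://10.0.0.1:29500
print(f"tcp://[{host_v6}]:{port}")  # tcp://[fe80::1]:29500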
@@ -50,7 +50,7 @@ def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.T
    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return max_seqlen_in_batch, cu_seqlens, indices
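`cu_seqlens` is the cumulative-length prefix consumed by the varlen attention kernels. A worked example of the computation above for a padded batch of three sequences:

# Worked example: batch of 3 sequences with lengths 2, 4 and 1 under right padding.
import torch
import torch.nn.functional as F

padding_mask = torch.tensor([[1, 1, 0, 0],
                             [1, 1, 1, 1],
                             [1, 0, 0, 0]])
seqlens = padding_mask.sum(dim=-1, dtype=torch.int32)                        # [2, 4, 1]
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([0, 2, 6, 7], dtype=torch.int32)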
@@ -140,32 +140,29 @@ class RMSNorm(BaseLayerNorm):

class LayerNorm(BaseLayerNorm):
    r"""
    This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
    This is a wrapper around native LayerNorm. It is meant to be used only with the from_native_module interface.
    """

    def __init__(self) -> None:
        raise NotImplementedError(
            "LayerNorm is not implemented as a physical class. "
            "It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module."
            "It is meant to be used only with the from_native_module interface to convert a native LayerNorm module to colossalai layer norm module."
        )

    @staticmethod
    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
        r"""
        Convert a native pytorch layer norm module to colossalai layer norm module,
        Convert a native LayerNorm module to colossalai layer norm module,
        and optionally marking parameters for gradient aggregation.

        Args:
            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
            module (nn.Module): The native LayerNorm module to be converted.
            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

        Returns:
            nn.Module: The LayerNorm module.
            nn.Module: The colossalai LayerNorm module.

        Raises:
            AssertionError: If the provided module is not an instance of nn.LayerNorm.
        """
        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."

        LazyInitContext.materialize(module)

@@ -174,7 +171,8 @@ class LayerNorm(BaseLayerNorm):
            # aggregation of these gradients is necessary during backpropagation.
            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
            SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
            if module.bias is not None:
                SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)

        return module
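With the signature relaxed from `nn.LayerNorm` to `nn.Module` and the bias marking guarded by `if module.bias is not None`, the converter also tolerates LayerNorm-like modules without a bias. A hedged usage sketch (bias=False requires a recent PyTorch):

# Hedged usage sketch: converting a bias-free LayerNorm in place; the module is
# returned unchanged apart from optional sequence-parallel gradient marking.
import torch.nn as nn

ln = nn.LayerNorm(1024, bias=False)          # a module whose bias is None
ln = LayerNorm.from_native_module(ln, sp_partial_derived=True)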
@@ -187,31 +185,29 @@ class FusedLayerNorm(BaseLayerNorm):
    def __init__(self) -> None:
        raise NotImplementedError(
            "FusedLayerNorm is not implemented as a physical class. "
            "It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
            "It is meant to be used only with the from_native_module interface convert a native LayerNorm module to FusedLayerNorm module provided by apex."
        )

    @staticmethod
    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
        r"""
        Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex,
        Convert a native LayerNorm module to FusedLayerNorm module provided by apex,
        and optionally marking parameters for gradient aggregation.

        Args:
            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
            module (nn.Module): The native LayerNorm module to be converted.
            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

        Returns:
            nn.Module: Union[FastLayerNorm, FusedLayerNorm].

        Raises:
            AssertionError: If the provided module is not an instance of nn.LayerNorm.
        """

        LazyInitContext.materialize(module)
        # get the attributes of the module
        normalized_shape = module.normalized_shape
        eps = module.eps
        elementwise_affine = module.elementwise_affine
        normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
        eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
        elementwise_affine = getattr(module, "elementwise_affine", True)
        dtype = module.weight.dtype
        device = module.weight.device

@@ -229,7 +225,7 @@ class FusedLayerNorm(BaseLayerNorm):
            ApexFusedLayerNorm = FusedLayerNormWithHook
        except NameError:
            warnings.warn(
                "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
                "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using native layernorm instead."
            )
            return module

@@ -237,7 +233,8 @@ class FusedLayerNorm(BaseLayerNorm):
            ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
        )
        layernorm.weight = module.weight
        layernorm.bias = module.bias
        if module.bias is not None:
            layernorm.bias = module.bias

        if sp_partial_derived:
            # Since gradients are computed using only a subset of the data,
@@ -475,7 +475,10 @@ class BloomPipelineForwards:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1
                logger.warning(
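The same ONNX-compatible last-token lookup is applied to the Falcon, GPT2 and GPTJ pipeline forwards later in this diff. A small worked example of why the modulo handles the "no pad token present" row:

# Worked example of the ONNX-friendly last-token index (pad_token_id = 0).
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0],    # padded: last real token at index 2
                          [5, 6, 7, 8]])   # no padding at all
lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1   # [2, -1]
lengths = lengths % input_ids.shape[-1]                            # [2, 3]
print(lengths)
# Row 0: first pad at position 3 -> last real token is 2. Row 1: argmax of an
# all-zero mask is 0, so -1 wraps to the final position 3 instead of relying on
# negative indexing, which ONNX export cannot handle.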
@ -0,0 +1,692 @@
|
|||
import math
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.models.cohere.modeling_cohere import (
|
||||
CohereForCausalLM,
|
||||
CohereModel,
|
||||
StaticCache,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer._operation import (
|
||||
all_to_all_comm,
|
||||
gather_forward_split_backward,
|
||||
split_forward_gather_backward,
|
||||
)
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import ColoAttention, cross_entropy_1d
|
||||
|
||||
|
||||
class CommandPipelineForwards:
|
||||
"""
|
||||
This class serves as a micro library for forward function substitution of Command models
|
||||
under pipeline setting.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def command_model_forward(
|
||||
self: CohereModel,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with pipeline parallelism. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape[:2]
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape[:2]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
past_seen_tokens = 0
|
||||
if use_cache: # kept for BC (cache positions)
|
||||
if not isinstance(past_key_values, StaticCache):
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_seen_tokens = past_key_values.get_seq_length()
|
||||
if cache_position is None:
|
||||
if isinstance(past_key_values, StaticCache):
|
||||
raise ValueError("cache_position is a required argument when using StaticCache.")
|
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)
|
||||
|
||||
seq_length_with_past = seq_length + past_seen_tokens
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
if use_cache:
|
||||
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||
use_cache = False
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
if shard_config.enable_flash_attention:
|
||||
# in this case, attention_mask is a dict rather than a tensor
|
||||
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
num_ckpt_layers = 0
|
||||
if self.gradient_checkpointing and self.training:
|
||||
num_ckpt_layers = end_idx - start_idx
|
||||
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
|
||||
if shard_config.gradient_checkpoint_config is not None:
|
||||
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
|
||||
stage=stage_manager.stage,
|
||||
num_stages=stage_manager.num_stages,
|
||||
num_layers=end_idx - start_idx,
|
||||
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
)
|
||||
assert num_ckpt_layers <= end_idx - start_idx
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if idx - start_idx < num_ckpt_layers:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if stage_manager.is_last_stage():
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
# always return a dict for intermediate stages
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def command_for_causal_lm_forward(
|
||||
self: CohereForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, CohereForCausalLM
|
||||
|
||||
>>> model = CohereForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
logger = logging.get_logger(__name__)
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = CommandPipelineForwards.command_model_forward(
|
||||
self.model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
past_key_values = None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits * self.logit_scale
|
||||
logits = logits.float()
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
dtype=self.model.dtype,
|
||||
)
|
||||
else:
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
else:
|
||||
hidden_states = outputs.get("hidden_states")
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
|
||||
def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||
if sp_mode is not None:
|
||||
assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
|
||||
assert (sp_size is not None) and (
|
||||
sp_group is not None
|
||||
), "Must specify sp_size and sp_group for sequence parallel"
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
# sp: modify sp_len when sequence parallel mode is ring
|
||||
if sp_mode in ["split_gather", "ring"]:
|
||||
q_len *= sp_size
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# sp: all-to-all communication when introducing sequence parallelism
|
||||
if sp_mode == "all_to_all":
|
||||
query_states = all_to_all_comm(query_states, sp_group)
|
||||
key_states = all_to_all_comm(key_states, sp_group)
|
||||
value_states = all_to_all_comm(value_states, sp_group)
|
||||
bsz, q_len, _ = query_states.size()
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if shard_config.enable_flash_attention:
|
||||
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
|
||||
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
|
||||
else:
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
# sp: all-to-all communication when introducing sequence parallelism
|
||||
if sp_mode == "all_to_all":
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
|
||||
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
|
||||
else:
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
past_seen_tokens = 0
|
||||
seq_len = inputs_embeds.shape[1]
|
||||
if use_cache: # kept for BC (cache positions)
|
||||
if not isinstance(past_key_values, StaticCache):
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_seen_tokens = past_key_values.get_seq_length()
|
||||
if cache_position is None:
|
||||
if isinstance(past_key_values, StaticCache):
|
||||
raise ValueError("cache_position is a required argument when using StaticCache.")
|
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# in this case, attention_mask is a dict rather than a tensor
|
||||
if shard_config.enable_flash_attention:
|
||||
mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
inputs_embeds.dtype,
|
||||
inputs_embeds.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
||||
|
||||
if sp_mode in ["ring", "split_gather"]:
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
|
||||
elif sp_mode == "all_to_all":
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if sp_mode == "ring" or sp_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = (
|
||||
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
|
||||
)
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||
from transformers import CohereForCausalLM
|
||||
|
||||
def forward(
|
||||
self: CohereForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, CohereForCausalLM
|
||||
|
||||
>>> model = CohereForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits * self.logit_scale
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
dtype=self.model.dtype,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
return forward
|
|
@@ -291,18 +291,17 @@ class FalconPipelineForwards:
            if attention_mask_2d is None:
                attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
            else:
                min_dtype = torch.finfo(alibi.dtype).min
                attention_mask = torch.masked_fill(
                    alibi / math.sqrt(self.config.hidden_size // self.num_heads),
                    attention_mask < -1,
                    torch.finfo(alibi.dtype).min,
                    min_dtype,
                )

                # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
                # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
                if seq_length > 1:
                    attention_mask = AttentionMaskConverter._unmask_unattended(
                        attention_mask, attention_mask_2d, unmasked_value=0.0
                    )
                if seq_length > 1 and attention_mask.device.type == "cuda":
                    attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
        else:
            # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
            attention_mask = _prepare_4d_causal_attention_mask(

@@ -543,7 +542,10 @@ class FalconPipelineForwards:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1).to(logits.device)
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1
                logger.warning(
@@ -738,7 +738,10 @@ class GPT2PipelineForwards:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1
                logger.warning_once(
@@ -32,6 +32,7 @@ def _get_attention_mask(
    hidden_states: torch.Tensor,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]],
    attention_mask: Optional[torch.FloatTensor],
    use_flash_attention_2: bool = False,
) -> Optional[Union[torch.Tensor, dict]]:
    batch_size, seq_len = hidden_states.shape[:2]
    past_key_values_length = 0

@@ -47,7 +48,7 @@ def _get_attention_mask(
            attention_mask,
            is_causal=True,
        )
    elif attention_mask is not None:
    elif use_flash_attention_2 and attention_mask is not None:
        if batch_size <= 0:
            raise ValueError("batch_size has to be defined and > 0")
        attention_mask = attention_mask.view(batch_size, -1)

@@ -162,7 +163,9 @@ class GPTJPipelineForwards:

        output_shape = input_shape + (hidden_states.size(-1),)

        attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
        attention_mask = _get_attention_mask(
            self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
        )

        if self.gradient_checkpointing and self.training:
            if use_cache:

@@ -419,7 +422,10 @@ class GPTJPipelineForwards:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1
                logger.warning_once(

@@ -712,7 +718,9 @@ def gptj_model_forward_for_flash_attention(shard_config: ShardConfig):

        hidden_states = self.drop(hidden_states)

        attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
        attention_mask = _get_attention_mask(
            self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
        )

        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)


@@ -886,7 +894,9 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)
        attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
        attention_mask = _get_attention_mask(
            self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
        )

        if self.gradient_checkpointing and self.training:
            if use_cache:
@ -7,11 +7,7 @@ import torch.nn.functional as F
|
|||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
|
@ -21,6 +17,7 @@ from transformers.models.llama.modeling_llama import (
|
|||
LlamaForCausalLM,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaModel,
|
||||
StaticCache,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
|
@ -55,6 +52,7 @@ class LlamaPipelineForwards:
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
|
@ -67,6 +65,11 @@ class LlamaPipelineForwards:
|
|||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with pipeline parallelism. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
|
@ -83,14 +86,24 @@ class LlamaPipelineForwards:
|
|||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
past_seen_tokens = 0
|
||||
if use_cache: # kept for BC (cache positions)
|
||||
if not isinstance(past_key_values, StaticCache):
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_seen_tokens = past_key_values.get_seq_length()
|
||||
if cache_position is None:
|
||||
if isinstance(past_key_values, StaticCache):
|
||||
raise ValueError("cache_position is a required argument when using StaticCache.")
|
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)
|
||||
|
||||
seq_length_with_past = seq_length + past_seen_tokens
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
|
@ -103,18 +116,8 @@ class LlamaPipelineForwards:
|
|||
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||
use_cache = False
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# embed positions: for the first stage, hidden_states is the input embeddings;
# for the other stages, hidden_states is the output of the previous stage
|
||||
|
@ -129,28 +132,9 @@ class LlamaPipelineForwards:
|
|||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._use_sdpa and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
hidden_states,
|
||||
past_key_values_length,
|
||||
)
|
||||
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
|
@ -190,6 +174,7 @@ class LlamaPipelineForwards:
|
|||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
|
@ -199,6 +184,7 @@ class LlamaPipelineForwards:
|
|||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
@ -249,6 +235,7 @@ class LlamaPipelineForwards:
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
|
@ -306,6 +293,7 @@ class LlamaPipelineForwards:
|
|||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
|
@ -368,6 +356,7 @@ class LlamaPipelineForwards:
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
|
@ -401,6 +390,7 @@ class LlamaPipelineForwards:
|
|||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
|
@ -470,36 +460,53 @@ class LlamaPipelineForwards:
|
|||
return {"hidden_states": hidden_states}
|
||||
|
||||
|
||||
def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
|
||||
|
||||
try:
|
||||
from transformers.models.llama.modeling_llama import repeat_kv
|
||||
except:
|
||||
warnings.warn("using llama v1, which does not provide the repeat_kv function")
|
||||
|
||||
def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
|
||||
def forward(
|
||||
self: LlamaAttention,
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[dict] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||
if sp_mode is not None:
|
||||
assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
|
||||
assert (sp_size is not None) and (
|
||||
sp_group is not None
|
||||
), "Must specify sp_size and sp_group for sequence parallel"
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
|
||||
)
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
# sp: modify sp_len when sequence parallel mode is ring
|
||||
if sp_mode in ["split_gather", "ring"]:
|
||||
q_len *= sp_size
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
if self.config.pretraining_tp > 1:
|
||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# sp: all-to-all communication when introducing sequence parallelism
|
||||
if sp_mode == "all_to_all":
|
||||
|
@ -520,39 +527,76 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
|
|||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
|
||||
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
if shard_config.enable_flash_attention:
|
||||
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
|
||||
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
|
||||
else:
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
# sp: all-to-all communication when introducing sequence parallelism
|
||||
if sp_mode == "all_to_all":
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
|
||||
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
else:
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
if self.config.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
||||
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
|
||||
def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
|
||||
logger = logging.get_logger(__name__)
|
||||
assert shard_config.enable_flash_attention, "Flash Attention is not enabled."
|
||||
|
||||
def forward(
|
||||
self: LlamaModel,
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
|
@ -562,119 +606,122 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# in this case, attention_mask is a dict rather than a tensor
|
||||
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
past_seen_tokens = 0
|
||||
seq_len = inputs_embeds.shape[1]
|
||||
if use_cache: # kept for BC (cache positions)
|
||||
if not isinstance(past_key_values, StaticCache):
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_seen_tokens = past_key_values.get_seq_length()
|
||||
if cache_position is None:
|
||||
if isinstance(past_key_values, StaticCache):
|
||||
raise ValueError("cache_position is a required argument when using StaticCache.")
|
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# in this case, attention_mask is a dict rather than a tensor
|
||||
if shard_config.enable_flash_attention:
|
||||
mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
inputs_embeds.dtype,
|
||||
inputs_embeds.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
||||
|
||||
if sp_mode in ["ring", "split_gather"]:
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
|
||||
elif sp_mode == "all_to_all":
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, past_key_value, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if sp_mode == "ring" or sp_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = (
|
||||
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
|
||||
)
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
|
@ -700,6 +747,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
|
@ -744,6 +792,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
@ -786,266 +835,3 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
)
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
# sp: modify sp_len when sequence parallel mode is ring
|
||||
if sp_mode in ["split_gather", "ring"]:
|
||||
q_len *= sp_size
|
||||
if self.config.pretraining_tp > 1:
|
||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# sp: all-to-all communication when introducing sequence parallelism
|
||||
if sp_mode == "all_to_all":
|
||||
query_states = all_to_all_comm(query_states, sp_group)
|
||||
key_states = all_to_all_comm(key_states, sp_group)
|
||||
value_states = all_to_all_comm(value_states, sp_group)
|
||||
bsz, q_len, _ = query_states.size()
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
# sp: all-to-all communication when introducing sequence parallelism
|
||||
if sp_mode == "all_to_all":
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
|
||||
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
|
||||
else:
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
||||
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
# modify past_key_values_length when using sequence parallelism
|
||||
past_key_values_length *= sp_size
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if sp_mode in ["ring", "split_gather"]:
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
|
||||
elif sp_mode == "all_to_all":
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, past_key_value, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
)
|
||||
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if sp_mode == "ring" or sp_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
|
|
@ -4,7 +4,10 @@ from typing import List, Optional, Tuple, Union
|
|||
import torch
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
|
@ -77,7 +80,7 @@ class MistralForwards:
|
|||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
|
||||
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
|
@ -97,9 +100,18 @@ class MistralForwards:
|
|||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
if self._use_flash_attention_2:
|
||||
if self._attn_implementation == "flash_attention_2":
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._attn_implementation == "sdpa" and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
|
@ -462,7 +474,7 @@ def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
|
||||
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
|
@ -481,9 +493,18 @@ def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
if self._use_flash_attention_2:
|
||||
if self._attn_implementation == "flash_attention_2":
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._attn_implementation == "sdpa" and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
|
|
|
@ -17,6 +17,7 @@ from transformers.modeling_outputs import (
|
|||
SequenceClassifierOutput,
|
||||
)
|
||||
from transformers.models.whisper.modeling_whisper import (
|
||||
_HIDDEN_STATES_START_POSITION,
|
||||
WhisperDecoder,
|
||||
WhisperEncoder,
|
||||
WhisperForAudioClassification,
|
||||
|
@ -166,6 +167,7 @@ def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):
|
|||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
position_ids=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
|
@ -199,9 +201,13 @@ def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):
|
|||
|
||||
# embed positions
|
||||
if input_ids is not None:
|
||||
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
|
||||
positions = self.embed_positions(
|
||||
input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
|
||||
)
|
||||
else:
|
||||
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
|
||||
positions = self.embed_positions(
|
||||
inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds + positions
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
@ -599,6 +605,7 @@ class WhisperPipelineForwards:
|
|||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
position_ids=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
|
@ -716,9 +723,13 @@ class WhisperPipelineForwards:
|
|||
|
||||
# embed positions
|
||||
if input_ids is not None:
|
||||
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
|
||||
positions = self.embed_positions(
|
||||
input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
|
||||
)
|
||||
else:
|
||||
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
|
||||
positions = self.embed_positions(
|
||||
inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds + positions
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
@ -841,6 +852,7 @@ class WhisperPipelineForwards:
|
|||
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
|
||||
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
|
@ -944,6 +956,7 @@ class WhisperPipelineForwards:
|
|||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
position_ids=decoder_position_ids,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
|
@ -986,6 +999,7 @@ class WhisperPipelineForwards:
|
|||
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
|
||||
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
|
@ -1048,6 +1062,7 @@ class WhisperPipelineForwards:
|
|||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
decoder_position_ids=decoder_position_ids,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
|
@ -1118,6 +1133,12 @@ class WhisperPipelineForwards:
|
|||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
if self.config.use_weighted_layer_sum:
|
||||
output_hidden_states = True
|
||||
elif output_hidden_states is None:
|
||||
output_hidden_states = self.config.output_hidden_states
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# audio classification only uses the encoder
|
||||
|
@ -1138,7 +1159,8 @@ class WhisperPipelineForwards:
|
|||
return encoder_outputs
|
||||
|
||||
if self.config.use_weighted_layer_sum:
|
||||
hidden_states = torch.stack(encoder_outputs, dim=1)
|
||||
hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
|
||||
hidden_states = torch.stack(hidden_states, dim=1)
|
||||
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
||||
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
||||
else:
|
||||
|

@ -192,6 +192,13 @@ _POLICY_LIST = {
    "transformers.models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification": PolicyLocation(
        file_name="qwen2", class_name="Qwen2ForSequenceClassificationPolicy"
    ),
    # Command-R
    "transformers.models.cohere.modeling_cohere.CohereModel": PolicyLocation(
        file_name="command", class_name="CommandModelPolicy"
    ),
    "transformers.models.cohere.modeling_cohere.CohereForCausalLM": PolicyLocation(
        file_name="command", class_name="CommandForCausalLMPolicy"
    ),
}

@ -67,7 +67,7 @@ class BertPolicy(Policy):
        else:
            norm_cls = col_nn.LayerNorm

        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
        sp_mode = self.shard_config.sequence_parallelism_mode or None
        assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for Bert"
        if sp_mode == "ring":
            warnings.warn(

@ -50,7 +50,7 @@ class BloomPolicy(Policy):
        else:
            norm_cls = col_nn.LayerNorm

        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
        sp_mode = self.shard_config.sequence_parallelism_mode or None
        assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for BLOOM"
        if sp_mode == "ring":
            warnings.warn(

@ -57,7 +57,7 @@ class ChatGLMPolicy(Policy):
        else:
            norm_cls = col_nn.LayerNorm

        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
        sp_mode = self.shard_config.sequence_parallelism_mode or None
        assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for ChatGLM2"
        if sp_mode == "ring":
            warnings.warn(

@ -0,0 +1,369 @@
|
|||
import warnings
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
|
||||
from colossalai.shardformer.layer import (
|
||||
FusedLayerNorm,
|
||||
LayerNorm,
|
||||
Linear1D_Col,
|
||||
Linear1D_Row,
|
||||
PaddingEmbedding,
|
||||
PaddingLMHead,
|
||||
VocabParallelEmbedding1D,
|
||||
VocabParallelLMHead1D,
|
||||
)
|
||||
|
||||
from ..modeling.command import (
|
||||
CommandPipelineForwards,
|
||||
get_command_flash_attention_forward,
|
||||
get_command_flash_attention_model_forward,
|
||||
get_lm_forward_with_dist_cross_entropy,
|
||||
)
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["CommandPolicy", "CommandForCausalLMPolicy"]
|
||||
|
||||
|
||||
class CommandPolicy(Policy):
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
self.origin_attn_implement = self.model.config._attn_implementation
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
from transformers.models.cohere.modeling_cohere import (
|
||||
CohereAttention,
|
||||
CohereDecoderLayer,
|
||||
CohereFlashAttention2,
|
||||
CohereModel,
|
||||
CohereSdpaAttention,
|
||||
)
|
||||
|
||||
ATTN_IMPLEMENTATION = {
|
||||
"eager": CohereAttention,
|
||||
"flash_attention_2": CohereFlashAttention2,
|
||||
"sdpa": CohereSdpaAttention,
|
||||
}
|
||||
policy = {}
|
||||
|
||||
attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = VocabParallelEmbedding1D
|
||||
else:
|
||||
if self.tie_weight:
|
||||
embedding_cls = PaddingEmbedding
|
||||
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
norm_cls = FusedLayerNorm
|
||||
else:
|
||||
norm_cls = LayerNorm
|
||||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
self.shard_config.enable_sequence_overlap = False
|
||||
self.shard_config.sequence_parallelism_mode = None
|
||||
warnings.warn(
"For Command, sequence parallelism is currently not compatible with pipeline parallelism; setting it to False."
|
||||
)
|
||||
sp_mode = self.shard_config.sequence_parallelism_mode or None
|
||||
sp_size = self.shard_config.sequence_parallel_size or None
|
||||
sp_group = self.shard_config.sequence_parallel_process_group or None
|
||||
sp_partial_derived = sp_mode in ["split_gather", "ring"]
|
||||
|
||||
if sp_mode == "all_to_all":
|
||||
decoder_attribute_replacement = {
|
||||
"num_heads": self.model.config.num_attention_heads // sp_size,
|
||||
}
|
||||
if getattr(self.model.config, "num_key_value_heads", False):
|
||||
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
|
||||
|
||||
policy[attn_cls] = ModulePolicyDescription(
|
||||
attribute_replacement=decoder_attribute_replacement,
|
||||
)
|
||||
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=attn_cls,
|
||||
)
|
||||
if self.pipeline_stage_manager is None:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_command_flash_attention_model_forward(
|
||||
self.shard_config,
|
||||
sp_mode=sp_mode,
|
||||
sp_size=sp_size,
|
||||
sp_group=sp_group,
|
||||
),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=CohereModel,
|
||||
)
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
||||
), "The number of attention heads must be divisible by the tensor parallel size."
|
||||
if hasattr(self.model.config, "num_key_value_heads"):
|
||||
assert (
|
||||
self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
|
||||
and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
|
||||
), "The number of key_value heads must be divisible by, and must not be less than, the tensor parallel size."
|
||||
decoder_attribute_replacement = {
|
||||
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
}
|
||||
if getattr(self.model.config, "num_key_value_heads", False):
|
||||
decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
|
||||
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
|
||||
)
|
||||
|
||||
policy[CohereDecoderLayer] = ModulePolicyDescription(
|
||||
attribute_replacement=decoder_attribute_replacement,
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.up_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.down_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
policy=policy,
|
||||
target_key=CohereModel,
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="input_layernorm",
|
||||
target_module=norm_cls,
|
||||
kwargs={"sp_partial_derived": sp_partial_derived},
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=CohereDecoderLayer,
|
||||
)
|
||||
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="norm",
|
||||
target_module=norm_cls,
|
||||
kwargs={"sp_partial_derived": sp_partial_derived},
|
||||
),
|
||||
policy=policy,
|
||||
target_key=CohereModel,
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under a pipeline-parallel setting, replace the original Hugging Face forward method
with a customized forward method, and add this change to the policy."""
|
||||
if self.pipeline_stage_manager is None:
|
||||
return
|
||||
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if self.model.__class__.__name__ == "CohereModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
|
||||
if stage_manager.is_interleave:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {
|
||||
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
|
||||
}
|
||||
|
||||
else:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
|
||||
)
|
||||
}
|
||||
|
||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
assert self.pipeline_stage_manager is not None
|
||||
|
||||
if self.model.__class__.__name__ == "CohereModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
if stage_manager.is_interleave:
|
||||
assert stage_manager.num_model_chunks is not None
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(module.embed_tokens)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(module.norm)
|
||||
|
||||
else:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
|
||||
return held_layers
|
||||
|
||||
|
||||
class CommandModelPolicy(CommandPolicy):
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
from transformers.models.cohere.modeling_cohere import CohereModel
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=CohereModel, new_forward=CommandPipelineForwards.command_model_forward, policy=policy
|
||||
)
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
held_layers = super().get_held_layers()
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
"""No shared params in command model"""
|
||||
return []
|
||||
|
||||
|
||||
class CommandForCausalLMPolicy(CommandPolicy):
|
||||
def module_policy(self):
|
||||
from transformers import CohereForCausalLM
|
||||
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for causal lm
|
||||
new_item = {
|
||||
CohereForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=VocabParallelLMHead1D,
|
||||
kwargs={
|
||||
"gather_output": not self.shard_config.parallel_output,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
}
|
||||
if self.shard_config.parallel_output:
|
||||
new_item[CohereForCausalLM].method_replacement = {
|
||||
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
|
||||
}
|
||||
else:
|
||||
new_item = {
|
||||
CohereForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=PaddingLMHead,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
)
|
||||
],
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=CohereForCausalLM,
|
||||
new_forward=CommandPipelineForwards.command_for_causal_lm_forward,
|
||||
policy=policy,
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
command_model = self.model.model
|
||||
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||
if (
|
||||
id(command_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||
and self.pipeline_stage_manager.num_stages > 1
|
||||
):
|
||||
# tie weights
|
||||
return [
|
||||
{
|
||||
0: command_model.embed_tokens.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
}
|
||||
]
|
||||
return []
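For orientation only, here is a minimal sketch of how a policy such as the new `CommandPolicy` is typically exercised through ShardFormer once it is registered in `_POLICY_LIST` (see the auto-policy hunk above). The `ShardConfig`/`ShardFormer` entry points, their keyword arguments, and the checkpoint name are assumptions based on the public ColossalAI and Hugging Face APIs, not part of this diff.

```python
# Hypothetical usage sketch -- entry points and arguments are assumed, not taken from this diff.
# Distributed initialization (e.g. colossalai.launch) is omitted for brevity.
from transformers import AutoModelForCausalLM

from colossalai.shardformer import ShardConfig, ShardFormer

# Any Cohere/Command-R checkpoint; the name below is illustrative.
model = AutoModelForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")

# Policies registered in _POLICY_LIST are resolved automatically from the model class,
# so no explicit CommandPolicy object needs to be passed here.
shard_config = ShardConfig(enable_tensor_parallelism=True, enable_flash_attention=True)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model)
```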

@ -65,7 +65,7 @@ class GPT2Policy(Policy):
        else:
            norm_cls = col_nn.LayerNorm

        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
        sp_mode = self.shard_config.sequence_parallelism_mode or None
        assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for GPT2"
        if sp_mode == "ring":
            warnings.warn(

@ -34,15 +34,11 @@ class GPTJPolicy(Policy):
        return self.model

    def module_policy(self):
        from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel

        ATTN_IMPLEMENTATION = {
            "eager": GPTJAttention,
        }
        from transformers.models.gptj.modeling_gptj import GPTJ_ATTENTION_CLASSES, GPTJBlock, GPTJModel

        policy = {}

        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
        attn_cls = GPTJ_ATTENTION_CLASSES[self.origin_attn_implement]

        embedding_cls = None
        if self.shard_config.enable_tensor_parallelism:

@ -20,9 +20,7 @@ from colossalai.shardformer.layer import (
|
|||
from ..modeling.llama import (
|
||||
LlamaPipelineForwards,
|
||||
get_llama_flash_attention_forward,
|
||||
get_llama_model_forward_for_flash_attn,
|
||||
get_llama_seq_parallel_attention_forward,
|
||||
get_llama_seq_parallel_model_forward,
|
||||
get_llama_flash_attention_model_forward,
|
||||
get_lm_forward_with_dist_cross_entropy,
|
||||
)
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
@ -75,40 +73,12 @@ class LlamaPolicy(Policy):
warnings.warn(
f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
)
sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None
sp_group = (
self.shard_config.sequence_parallel_process_group if self.shard_config.enable_sequence_parallelism else None
)
sp_mode = self.shard_config.sequence_parallelism_mode or None
sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"]

use_flash_attention = self.shard_config.enable_flash_attention
# Currently sp cannot be used with flash attention
if sp_mode in ["split_gather", "ring", "all_to_all"]:
if use_flash_attention:
warnings.warn(
f"Sequence parallelism mode {sp_mode} needs to be used with FlashAttention, will disable FlashAttention automatically."
)
use_flash_attention = False

if sp_mode in ["split_gather", "ring"]:
self.append_or_create_method_replacement(
description={
"forward": get_llama_seq_parallel_model_forward(
sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group
),
},
policy=policy,
target_key=LlamaModel,
)
self.append_or_create_method_replacement(
description={
"forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,
)
elif sp_mode == "all_to_all":
if sp_mode == "all_to_all":
decoder_attribute_replacement = {
"num_heads": self.model.config.num_attention_heads // sp_size,
}
@ -118,24 +88,27 @@ class LlamaPolicy(Policy):
policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
self.append_or_create_method_replacement(
description={
"forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
"forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,
)
self.append_or_create_method_replacement(
description={
"forward": get_llama_seq_parallel_model_forward(
sp_mode=sp_mode,
sp_size=sp_size,
sp_group=sp_group,
),
},
policy=policy,
target_key=LlamaModel,
)
if self.pipeline_stage_manager is None:
self.append_or_create_method_replacement(
description={
"forward": get_llama_flash_attention_model_forward(
self.shard_config,
sp_mode=sp_mode,
sp_size=sp_size,
sp_group=sp_group,
),
},
policy=policy,
target_key=LlamaModel,
)

if self.shard_config.enable_tensor_parallelism:
assert (
@ -235,25 +208,6 @@ class LlamaPolicy(Policy):
target_key=LlamaModel,
)

# use flash attention
if use_flash_attention:
self.append_or_create_method_replacement(
description={
"forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
},
policy=policy,
target_key=attn_cls,
)
if self.pipeline_stage_manager is None:
# replace llama model forward method
self.append_or_create_method_replacement(
description={
"forward": get_llama_model_forward_for_flash_attn(self.shard_config),
},
policy=policy,
target_key=LlamaModel,
)

return policy

def postprocess(self):
@ -42,11 +42,13 @@ class MistralPolicy(Policy):
MistralDecoderLayer,
MistralFlashAttention2,
MistralModel,
MistralSdpaAttention,
)

ATTN_IMPLEMENTATION = {
"eager": MistralAttention,
"flash_attention_2": MistralFlashAttention2,
"sdpa": MistralSdpaAttention,
}

policy = {}
@ -25,6 +25,7 @@ class ChunkManager:
chunk_configuration,
init_device: Optional[torch.device] = None,
reuse_fp16_chunk: bool = True,
max_prefetch: int = 0,
) -> None:
self.device = init_device or get_accelerator().get_current_device()
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
@ -42,6 +43,7 @@ class ChunkManager:
# Whether model is accumulating gradients,
self.accumulating_grads = False
self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
self._prefetch_stream = get_accelerator().Stream() if max_prefetch else None

def register_tensor(
self,
@ -21,6 +21,7 @@ def init_chunk_manager(
hidden_dim: Optional[int] = None,
reuse_fp16_chunk: bool = True,
verbose: bool = False,
max_prefetch: int = 0,
**kwargs,
) -> ChunkManager:
if hidden_dim:
@ -51,9 +52,5 @@ def init_chunk_manager(
)
dist.barrier()

chunk_manager = ChunkManager(
config_dict,
init_device,
reuse_fp16_chunk=reuse_fp16_chunk,
)
chunk_manager = ChunkManager(config_dict, init_device, reuse_fp16_chunk=reuse_fp16_chunk, max_prefetch=max_prefetch)
return chunk_manager
@ -104,9 +104,7 @@ class GeminiDDP(ModelWrapper):
self.enable_gradient_accumulation = enable_gradient_accumulation
if chunk_config_dict is not None:
self.chunk_manager = ChunkManager(
chunk_config_dict,
chunk_init_device,
reuse_fp16_chunk=reuse_fp16_chunk,
chunk_config_dict, chunk_init_device, reuse_fp16_chunk=reuse_fp16_chunk, max_prefetch=max_prefetch
)
else:
# some ugly hotfix for the compatibility with Lightning
@ -122,6 +120,7 @@ class GeminiDDP(ModelWrapper):
process_group=zero_group,
reuse_fp16_chunk=reuse_fp16_chunk,
verbose=verbose,
max_prefetch=max_prefetch,
)
self.gemini_manager = GeminiManager(
placement_policy,
@ -451,6 +450,7 @@ class GeminiDDP(ModelWrapper):
chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
if not (master_weights) or (enable_gradient_accumulation):
chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
return empty_grad

def zero_grad(self, set_to_none: bool = False) -> None:
self.module.zero_grad(set_to_none=True)
@ -5,6 +5,7 @@ from typing import List

import torch

from colossalai.accelerator import get_accelerator
from colossalai.tensor.param_op_hook import ColoParamOpHook
from colossalai.utils import is_ddp_ignored
from colossalai.zero.gemini import TensorState
@ -54,10 +55,20 @@ class GeminiZeROHook(ColoParamOpHook):
)

# prefetch
for chunk in chunks_fetch_async:
maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
if maybe_work is not None:
self._gemini_manager.add_work(chunk, maybe_work)
if self._gemini_manager.chunk_manager._prefetch_stream is not None:
# This is for when prefetch happens the first time and there is no dist.Work to sync:
# there is a possibility that the optimizer hasn't finished computation on the default stream,
# thus we might prefetch outdated chunks there.
#
# Other than that, self._gemini_manager.wait_chunks will have synced with the default stream
# by calling dist.Work.wait(), and this line makes no difference.
self._gemini_manager.chunk_manager._prefetch_stream.wait_stream(torch.cuda.current_stream())

with get_accelerator().stream(self._gemini_manager.chunk_manager._prefetch_stream):
for chunk in chunks_fetch_async:
maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
if maybe_work is not None:
self._gemini_manager.add_work(chunk, maybe_work)

# record cuda model data of the current OP, including memory for prefetched chunks
self._gemini_manager.record_model_data_volume()
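Note: the following is an illustrative sketch, not part of the diff above. It shows how the new `max_prefetch` knob (threaded through `ChunkManager`, `init_chunk_manager` and `GeminiDDP`) might be enabled from user code; it assumes the `GeminiPlugin` booster plugin forwards `max_prefetch` down to `GeminiDDP`, and the model/optimizer objects are placeholders.

```python
# Hedged usage sketch: enabling Gemini chunk prefetching via the booster API.
# Assumption: GeminiPlugin exposes a `max_prefetch` argument that reaches GeminiDDP/ChunkManager.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch()

model = torch.nn.Linear(1024, 1024).cuda()         # placeholder model
optimizer = torch.optim.AdamW(model.parameters())  # placeholder optimizer

# max_prefetch > 0 lets the ZeRO hook fetch upcoming chunks asynchronously on a side stream
plugin = GeminiPlugin(placement_policy="auto", precision="fp16", max_prefetch=2)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
```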
@ -24,6 +24,7 @@
</div>

## News
* [2024/06] [Open-Sora Continues Open Source: Generate Any 16-Second 720p HD Video with One Click, Model Weights Ready to Use](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
* [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference)
* [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
* [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series)
@ -31,10 +32,6 @@
* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
* [2024/03] [Open-Sora: Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
* [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer)
* [2024/01] [Construct Refined 13B Private Model With Just $5000 USD, Upgraded Colossal-AI Llama-2 Open Source](https://hpc-ai.com/blog/colossal-llama-2-13b)
* [2023/11] [Enhanced MoE Parallelism, Open-source MoE Model Training Can Be 9 Times More Efficient](https://www.hpc-ai.tech/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient)
* [2023/09] [One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific LLM Solution](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution)
* [2023/09] [70 Billion Parameter LLaMA2 Model Training Accelerated by 195%](https://www.hpc-ai.tech/blog/70b-llama2-training)
* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)

## Table of Contents
@ -127,13 +124,13 @@ Colossal-AI provides a collection of parallel components. Our goal is to let your

[Open-Sora](https://github.com/hpcaitech/Open-Sora): Fully open-sourcing Sora-like model parameters and all training details
[[code]](https://github.com/hpcaitech/Open-Sora)
[[blog]](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
[[model weights]](https://huggingface.co/hpcai-tech/Open-Sora)
[[blog]](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
[[model weights]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#model-weights)
[[demo]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo)

<div align="center">
<a href="https://www.bilibili.com/video/BV1dW421c7MN">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/sora/sora-demo-cn.png" width="700" />
<a href="https://www.bilibili.com/video/BV1Fm421G7bV">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/sora/opensora-v1.2.png" width="700" />
</a>
</div>
@ -56,6 +56,7 @@
"features/pipeline_parallel",
"features/nvme_offload",
"features/lazy_init",
"features/distributed_optimizers",
"features/cluster_utils"
]
},
@ -4,9 +4,9 @@ Author: [Wenxuan Tan](https://github.com/Edenzzzz), [Junwen Duan](https://github

**Related Paper**
- [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)
- [CAME: Confidence-guided Adaptive Memory Efficient Optimization] (https://arxiv.org/abs/2307.02047)
- [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection] (https://arxiv.org/abs/2403.03507)
- [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes] (https://arxiv.org/pdf/1904.00962)
- [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047)
- [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507)
- [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/pdf/1904.00962)

## Introduction
Apart from the widely adopted Adam and SGD, many modern optimizers require layer-wise statistics to update parameters, and thus aren't directly applicable to settings where model layers are sharded across multiple devices. We provide optimized distributed implementations with minimal extra communication, and seamless integration with the Tensor Parallel, DDP and ZeRO plugins, which switch to the distributed optimizers automatically with zero code change.
@ -14,12 +14,6 @@ Apart from the widely adopted Adam and SGD, many modern optimizers require layer
## Optimizers
Adafactor is a first-order Adam variant using Non-negative Matrix Factorization (NMF) to reduce memory footprint. CAME improves on it by introducing a confidence matrix to correct NMF. GaLore further reduces memory by projecting gradients into a low-rank space and applying 8-bit block-wise quantization. Lamb allows huge batch sizes without losing accuracy via layer-wise adaptive updates bounded by the inverse of its Lipschitz constant.

## API Reference

{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}
{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }}
{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }}
{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}

## Hands-On Practice
We now demonstrate how to use Distributed Adafactor with the booster API combining Tensor Parallel and ZeRO 2 with 4 GPUs, as sketched below. **Note that even if you're not aware of distributed optimizers, the plugins automatically cast your optimizer to the distributed version for convenience.**
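Since the excerpt above elides the hands-on code itself, here is a minimal sketch of the kind of setup the passage describes. It is an assumption-laden illustration, not the documented example: it assumes `DistributedAdaFactor` is importable from `colossalai.nn.optimizer` (per the autodoc path above) and uses placeholder model names.

```python
# Hedged sketch: Tensor Parallel + ZeRO 2 via HybridParallelPlugin on 4 GPUs,
# with the plugin transparently handling the distributed Adafactor variant.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import DistributedAdaFactor  # assumed re-export

colossalai.launch_from_torch()

# placeholder model; any nn.Module works here
model = torch.nn.Sequential(torch.nn.Linear(1024, 4096), torch.nn.Linear(4096, 1024)).cuda()
optim = DistributedAdaFactor(model.parameters())

# tp_size=2 with the remaining 2 ranks forming the ZeRO/data-parallel group
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=2, precision="bf16")
booster = Booster(plugin=plugin)
model, optim, *_ = booster.boost(model, optim)
```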
@ -140,3 +134,10 @@ optim = DistGaloreAwamW(
</table>

<!-- doc-test-command: colossalai run --nproc_per_node 4 distributed_optimizers.py -->

## API Reference

{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}
{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }}
{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }}
{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}
@ -4,21 +4,15 @@ Author: Wenxuan Tan, Junwen Duan, Renjie Mao

**Related Paper**
- [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)
- [CAME: Confidence-guided Adaptive Memory Efficient Optimization] (https://arxiv.org/abs/2307.02047)
- [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection] (https://arxiv.org/abs/2403.03507)
- [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes] (https://arxiv.org/pdf/1904.00962)
- [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047)
- [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507)
- [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/pdf/1904.00962)

## Introduction
Apart from the widely adopted Adam and SGD, many modern optimizers require layer-wise statistics to update parameters effectively, so they cannot be applied directly to parallel settings where model layers are sharded across multiple devices. We provide optimized distributed implementations that integrate seamlessly with the Tensor Parallel, DDP and ZeRO plugins.
## Optimizers
Adafactor is a first-order Adam variant that uses Non-negative Matrix Factorization (NMF) to reduce memory footprint. CAME improves on it by introducing a confidence matrix to correct NMF. GaLore further reduces memory by projecting gradients into a low-rank space and applying 8-bit block-wise quantization. Lamb allows huge batch sizes without losing accuracy via layer-wise adaptive updates bounded by the inverse of the Lipschitz constant.

## API Reference

{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}
{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }}
{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }}
{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}

## Usage
We now demonstrate how to use Distributed Adafactor with the booster API combining Tensor Parallel and ZeRO 2. Even if you do not use a distributed optimizer, the plugin will automatically convert your optimizer to the distributed version for convenience.
@ -137,3 +131,10 @@ optim = DistGaloreAwamW(
</table>

<!-- doc-test-command: colossalai run --nproc_per_node 4 distributed_optimizers.py -->

## API Reference

{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}
{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }}
{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }}
{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}
@ -43,5 +43,8 @@ colossalai run --nproc_per_node 2 llama_generation.py -m PATH_MODEL --drafter_mo

If you want to try the GLIDE model (glide-vicuna7b) as the drafter model with vicuna-7B, you can provide the GLIDE model path or model card as the drafter model and enable the feature by:
```python
from colossalai.inference.modeling.models.glide_llama import GlideLlamaForCausalLM
drafter_model = GlideLlamaForCausalLM.from_pretrained(drafter_model_path_or_name)
...
engine.enable_spec_dec(drafter_model, use_glide_drafter=True)
```
@ -72,6 +72,7 @@ def main():
parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size")
parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
@ -174,6 +175,8 @@ def main():
tp_size=args.tp,
pp_size=args.pp,
zero_stage=args.zero,
sp_size=args.sp,
enable_sequence_parallelism=args.sp > 1,
enable_fused_normalization=torch.cuda.is_available(),
enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
@ -16,7 +16,7 @@ ray
sentencepiece
google
protobuf
transformers>=4.36.2,<4.40.0
transformers==4.39.3
peft>=0.7.1
bitsandbytes>=0.39.0
rpyc==6.0.0
@ -22,3 +22,9 @@ try:
from .qwen2 import *
except ImportError:
print("This version of transformers doesn't support qwen2.")


try:
from .command import *
except ImportError:
print("This version of transformers doesn't support Command-R.")
@ -33,22 +33,6 @@ loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(
)
loss_fn = lambda x: x["loss"]

config = AutoConfig.from_pretrained(
"THUDM/chatglm2-6b",
trust_remote_code=True,
num_layers=2,
padded_vocab_size=65024,
hidden_size=64,
ffn_hidden_size=214,
num_attention_heads=8,
kv_channels=16,
rmsnorm=True,
original_rope=True,
use_cache=True,
multi_query_attention=False,
torch_dtype=torch.float32,
)


infer_config = AutoConfig.from_pretrained(
"THUDM/chatglm2-6b",
@ -68,6 +52,21 @@ infer_config = AutoConfig.from_pretrained(


def init_chatglm():
config = AutoConfig.from_pretrained(
"THUDM/chatglm2-6b",
trust_remote_code=True,
num_layers=2,
padded_vocab_size=65024,
hidden_size=64,
ffn_hidden_size=214,
num_attention_heads=8,
kv_channels=16,
rmsnorm=True,
original_rope=True,
use_cache=True,
multi_query_attention=False,
torch_dtype=torch.float32,
)
model = AutoModelForCausalLM.from_config(config, empty_init=False, trust_remote_code=True)
for m in model.modules():
if m.__class__.__name__ == "RMSNorm":
@ -0,0 +1,79 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

try:
from transformers import CohereConfig

HAS_COMMAND = True
except ImportError:
HAS_COMMAND = False

if HAS_COMMAND:
# ===============================
# Register Command-R
# ===============================

def data_gen():
input_ids = torch.Tensor(
[
[1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
[1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
]
).long()

attention_mask = torch.Tensor(
[
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
]
).long()

return dict(input_ids=input_ids, attention_mask=attention_mask)

# label is needed for causal lm
def data_gen_for_casual_lm():
data = data_gen()
labels = data["input_ids"].clone()
data["labels"] = labels
return data

# transform the output to a dict
output_transform_fn = lambda x: x

# function to get the loss
loss_fn = lambda output: output["last_hidden_state"].mean()
loss_fn_for_casual_lm = lambda output: output["loss"]
loss_fn_for_seq_classification = lambda output: output["logits"].mean()

config = CohereConfig(
num_hidden_layers=8,
hidden_size=32,
intermediate_size=64,
num_attention_heads=4,
max_position_embeddings=128,
)

if hasattr(config, "pad_token_id"):
config.pad_token_id = config.eos_token_id

# register the following models
# transformers.CohereModel,
# transformers.CohereForCausalLM,
model_zoo.register(
name="transformers_command",
model_fn=lambda: transformers.CohereModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_command_for_casual_lm",
model_fn=lambda: transformers.CohereForCausalLM(config),
data_gen_fn=data_gen_for_casual_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_casual_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)
@ -4,7 +4,7 @@ import numpy as np
import pytest
import torch

from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask
@ -26,7 +26,7 @@ def prepare_data(
num_tokens = torch.sum(context_lengths).item()

max_seq_len_in_batch = context_lengths.max()
cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.int32), (1, 0))

kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
@ -28,15 +28,22 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
torch.manual_seed(10)
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
# our crafted op equals to Transformers
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)

position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))

emb = LlamaRotaryEmbedding(D)
cos, sin = emb(x0, TOTAL_TOKENS)

cos, sin = emb(x0, position_ids)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
cos = cos.reshape((TOTAL_TOKENS, -1))
sin = sin.reshape((TOTAL_TOKENS, -1))
cos_2 = cos[:, : D // 2]
sin_2 = sin[:, : D // 2]
position_ids = torch.arange(TOTAL_TOKENS)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
x2 = x0.transpose(1, 2).reshape(TOTAL_TOKENS, H, D)
embd_stimulated_x = torch_rotary_emb(x2, cos_2, sin_2)
embd_stimulated_x = embd_stimulated_x.reshape((BATCH_SIZE, SEQ_LEN, H, D)).transpose(1, 2)
assert torch.allclose(embd_x0, embd_stimulated_x)

# create data
@ -2,7 +2,7 @@ import pytest
import torch
from packaging import version

from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.triton import context_attention_unpadded
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.kernel_utils import (
@ -3,7 +3,7 @@ import pytest
import torch
from packaging import version

from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.kernel_utils import (
@ -43,15 +43,19 @@ def torch_rotary_emb(x, cos, sin):
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout):
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
# our crafted op equals to Transformers
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
emb = LlamaRotaryEmbedding(D)
cos, sin = emb(x0, TOTAL_TOKENS)
position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))
cos, sin = emb(x0, position_ids)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
cos = cos.reshape((TOTAL_TOKENS, -1))
sin = sin.reshape((TOTAL_TOKENS, -1))
cos_2 = cos[:, :32]
sin_2 = sin[:, :32]
position_ids = torch.arange(TOTAL_TOKENS)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
x2 = x0.transpose(1, 2).reshape(TOTAL_TOKENS, H, D)
embd_stimulated_x = torch_rotary_emb(x2, cos_2, sin_2)
embd_stimulated_x = embd_stimulated_x.reshape((BATCH_SIZE, SEQ_LEN, H, D)).transpose(1, 2)
assert torch.allclose(embd_x0, embd_stimulated_x)

# create data
@ -55,7 +55,7 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=output_len)
outputs = inference_engine.generate(generation_config=generation_config)
else:
if prompt_template:
@ -0,0 +1,161 @@
import os
import random

import numpy as np
import pytest
import torch
import torch.distributed as dist
from torch.multiprocessing import Manager
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, LlamaForCausalLM, LlamaTokenizer

import colossalai
import colossalai.inference.modeling.policy as policy
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

# NOTE: To test a model with the inference engine, you need to provide the path to your
# local pretrained model weights in the MODEL_MAP dictionary
MODEL_MAP = {
"baichuan": {
"model": AutoModelForCausalLM,
"tokenizer": AutoTokenizer,
"policy": policy.NoPaddingBaichuanModelInferPolicy,
"model_name_or_path": "baichuan-inc/Baichuan2-13B-Base",  # provide the path to local model weights
},
"llama": {
"model": LlamaForCausalLM,
"tokenizer": LlamaTokenizer,
"policy": policy.NoPaddingLlamaModelInferPolicy,
"model_name_or_path": "meta-llama/Llama-2-70b-hf",
},
}

MODELS_TO_TEST = ["llama", "baichuan"]  # Specify the models to test


@parameterize("model", MODELS_TO_TEST)
@parameterize("prompt_template", [None, "model_specific"])
@parameterize("do_sample", [False])
@parameterize("use_cuda_kernel", [True])
@pytest.mark.largedist
@rerun_if_address_is_in_use()
def test_model(model, prompt_template, do_sample, use_cuda_kernel):
model_path = MODEL_MAP[model]["model_name_or_path"]
if not os.path.exists(model_path):
pytest.skip(
f"There is no local model address included for {model}, please replace this address with a valid one."
)

if prompt_template == "model_specific":
prompt_template = model

model_config = MODEL_MAP[model]

kwargs1 = {
"model": model,
"use_engine": True,
"prompt_template": prompt_template,
"do_sample": do_sample,
"policy": model_config["policy"](),
"use_cuda_kernel": use_cuda_kernel,
}

kwargs2 = {
"model": model,
"use_engine": False,
"prompt_template": prompt_template,
"do_sample": do_sample,
"policy": None,
"use_cuda_kernel": use_cuda_kernel,
}

colossal_tp_1_output = run_engine(1, **kwargs1)
colossal_tp_2_output = run_engine(2, **kwargs1)
transformer_tp_1_output = run_engine(1, **kwargs2)

for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"


def run_engine(world_size, **kwargs):
manager = Manager()
result_list = manager.list([-1] * world_size)  # Create a shared list
spawn(run_dist, world_size, func_to_run=_run_engine, ret=result_list, **kwargs)
return result_list[0]


def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

if ret:
ret[rank] = func_to_run(**kwargs)
else:
func_to_run(**kwargs)


def _run_engine(model, use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None, policy=None):
setup_seed(20)
model_config = MODEL_MAP[model]
model_name_or_path = model_config["model_name_or_path"]
tokenizer = model_config["tokenizer"].from_pretrained(model_name_or_path, use_fast=False, trust_remote_code=True)
model = model_config["model"].from_pretrained(model_name_or_path, trust_remote_code=True).half().cuda()
model = model.eval()

inputs = [
"Introduce some landmarks in Paris:",
]

output_len = 38

if do_sample:
top_p = 0.5
top_k = 50
else:
top_p = None
top_k = None

if use_engine:
inference_config = InferenceConfig(
max_output_len=output_len,
prompt_template=prompt_template,
use_cuda_kernel=use_cuda_kernel,
tp_size=dist.get_world_size(),
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=output_len)
outputs = inference_engine.generate(generation_config=generation_config)
else:
if prompt_template:
# apply prompt template
inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
inputs = inputs.cuda()
generation_config = GenerationConfig(
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
pad_token_id=tokenizer.pad_token_id,
max_new_tokens=output_len,
)
outputs = model.generate(inputs, generation_config=generation_config)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
return outputs


def setup_seed(seed):
torch.manual_seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)


if __name__ == "__main__":
test_model()
@ -0,0 +1,322 @@
import os

import pytest
import torch
import torch.distributed as dist
from torch.testing import assert_close

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import PipelineGradientCheckpointConfig
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config
)
if enable_gradient_checkpointing:
# org_model.gradient_checkpointing_enable()
sharded_model.unwrap().gradient_checkpointing_enable()

org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)

stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group

# unwrap model
command_model = unwrap_model(org_model, "CohereModel", "model")
shard_command_model = unwrap_model(sharded_model, "CohereModel", "model")

row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
col_layer_for_check = ["layers[0].self_attn.o_proj"]
# Here we check the grad of layernorm because an all-reduce operation should be performed during sequence parallelism
norm_layer_for_check = ["layers[0].input_layernorm", "layers[1].input_layernorm"]

# During pipeline parallelism, we cannot get the grad of norm layer during first stage, so we only check this when pp is not enabled
if stage_manager is None:
norm_layer_for_check.append("norm")

# Check the grad when using ZeRO-1 and ZeRO-2
if (
booster.plugin.zero_stage in [1, 2]
and booster.plugin.shard_config.enable_sequence_parallelism
and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
):
for p1, p2 in zip(command_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
grad_index = (
0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank
)
grad = grads[grad_index]
sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)

# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
if test_config["precision"] == "fp32":
atol, rtol = 1e-6, 1e-4
else:
atol, rtol = 5e-3, 5e-3
row_layer_grads = get_grad_tensors_for_check(
command_model,
shard_command_model,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False,
)
col_layer_grads = get_grad_tensors_for_check(
command_model,
shard_command_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
norm_layer_grads = get_grad_tensors_for_check(
command_model,
shard_command_model,
norm_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
grads_to_check.update(norm_layer_grads)

# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()

# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3

if org_model.__class__.__name__ == "CohereModel":
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

# check weights
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
if test_config["precision"] == "fp32":
atol, rtol = 5e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_weight(
command_model,
shard_command_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)

# check grads
check_all_grad_tensors(grads_to_check)

torch.cuda.empty_cache()


@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 1,
"sp_size": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 2,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
"enable_gradient_checkpointing": True,
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"use_lazy_init": False,
"precision": "fp32",
"enable_gradient_checkpointing": True,
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
],
)
def run_command_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")

for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()


@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"pp_style": "interleaved",
"num_model_chunks": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
"enable_gradient_checkpointing": True,
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
num_ckpt_layers_per_stage=[0, 1, 2, 2],
),
},
],
)
def run_command_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")

for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()


def check_command(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_command_test()


def check_command_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_command_3d_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_command():
spawn(check_command, 4)


@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_command_3d():
spawn(check_command_3d, 8)


if __name__ == "__main__":
test_command()
test_command_3d()
@ -120,9 +120,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_weight(
llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
)
try:
check_weight(
llama_model,
shard_llama_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
except Exception as e:
print(f"Failed config: {test_config}")
raise e

# check grads
check_all_grad_tensors(grads_to_check)
@ -133,9 +144,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
{  # Test ring + Flash attention
"tp_size": 2,
"pp_size": 1,
"sp_size": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
@ -145,14 +157,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
{  # Ulysses + Flash attention
"tp_size": 1,
"pp_size": 2,
"sp_size": 2,
"num_microbatches": 2,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": False,
"sequence_parallelism_mode": "all_to_all",
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
@ -164,7 +178,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"use_lazy_init": True,
"zero_stage": 2,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
@ -213,7 +238,11 @@ def run_llama_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")

for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
try:
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
except Exception as e:
print(f"Failed config: {test_config}")
raise e

clear_layout_converter()
Randomizer.reset_index()
@ -263,7 +292,11 @@ def run_llama_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")

for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
try:
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
except Exception as e:
print(f"Failed config: {test_config}")
raise e

clear_layout_converter()
Randomizer.reset_index()
@ -217,6 +217,7 @@ def check_qwen2_3d(rank, world_size, port):


@pytest.mark.skipif(transformers.__version__ < "4.39.1", reason="Requires transformers version 4.39.1 or later")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_qwen2():
@ -224,6 +225,7 @@ def test_qwen2():


@pytest.mark.skipif(transformers.__version__ < "4.39.1", reason="Requires transformers version 4.39.1 or later")
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_qwen2_3d():
@ -1 +1 @@
0.3.8
0.3.9