mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] update transformers (#5583)
* flash_attention forward upgrade
* llama_model_forward
* remove useless comment
* update the requirements.txt
* add the transformers version requirements
* remove the LATEST VERSION try
* [shardformer] update bloom model (#5518)
  * update bloom model
  * remove the version restriction
* [shardformer] update_falcon (#5520)
* [shardformer] update mistral model (#5511)
* [shardformer] update gpt2 (#5502)
* [shardformer] update gptj model (#5503)
* [shardformer] update opt (#5522)
* [shardformer] update t5 model (#5524)
* [shardformer] update whisper model (#5529)
* [shardformer] update vit model (#5530)
  * update vit model
  * remove the output_hidden_states
* [shardformer] fix llama modeling
* [pre-commit.ci] auto fixes from pre-commit.com hooks (https://pre-commit.ci)
* [zero] support multiple (partial) backward passes (#5596)
* [misc] update requirements
* fix conflicts
* [doc] fix ColossalMoE readme (#5599)
* merge with main
* [hotfix] Fix examples no pad token & auto parallel codegen bug (#5606)
  * fix no pad token bug
  * fixed some auto parallel codegen bug, but might not run on torch 2.1
* [shardformer] fix pipeline grad ckpt (#5620)
* [shardformer] fix whisper (#5628)
* [test] fix llama model test
* fix the opt upgrade (#5634)
* [shardformer] fix attn replacement (#5636)
* [shardformer] update flashattention replacement (#5637)
  * update transformers; fix
* [test] fix llama test (#5638)
* [gemini] fix buffer cast (#5639)
* Fix shardformer upgrade (#5640)
  * fix llama model
  * fix the mistral
  * fix the shardformer model
* [shardformer] support pipeline parallelism for mistral (#5642)
* [Feature] Support LLaMA-3 CPT and ST (#5619)
* [example] update llama example (#5626)
  * [plugin] support dp inside for hybrid parallel
  * [example] update llama benchmark
  * [example] update llama readme
* [example] llama3 (#5631)
  * [release] llama3
* support pp for mistral

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: flybird11111 <1829166702@qq.com>
parent f4c5aafe29, commit 0d0a582033 (branch: pull/5627/head^2)
@@ -6,6 +6,7 @@ import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from torch.nn import functional as F
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -205,12 +206,13 @@ class BloomPipelineForwards:
         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

         # causal_mask is constructed every stage and its input is passed through different stages
-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
-
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         if shard_config and shard_config.enable_sequence_parallelism:
@@ -227,21 +229,15 @@ class BloomPipelineForwards:
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
@@ -1002,11 +998,13 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):

         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         hidden_states = split_forward_gather_backward(
@@ -1018,21 +1016,15 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
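The two Bloom hunks above swap the closure-based activation-checkpoint call for the `_gradient_checkpointing_func` helper that newer transformers releases attach to the model. A minimal sketch of the two patterns, assuming a transformers version that provides `_gradient_checkpointing_func`; this is illustrative code, not part of the diff.

# Illustrative sketch only (not part of the diff). Shows the call-pattern change in
# the hunks above, assuming a transformers release that installs
# `self._gradient_checkpointing_func` when gradient checkpointing is enabled.
import torch


def old_style(block, hidden_states, alibi, causal_mask, layer_past, head_mask, use_cache, output_attentions):
    # Pre-upgrade pattern: wrap the block in a closure so keyword-only extras
    # survive torch.utils.checkpoint's positional re-invocation on backward.
    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

        return custom_forward

    return torch.utils.checkpoint.checkpoint(
        create_custom_forward(block), hidden_states, alibi, causal_mask, layer_past, head_mask
    )


def new_style(model, block, hidden_states, alibi, causal_mask, layer_past, head_mask, use_cache, output_attentions):
    # Post-upgrade pattern: every argument is passed positionally to the helper
    # the model exposes, so no closure is needed.
    return model._gradient_checkpointing_func(
        block.__call__, hidden_states, alibi, causal_mask, layer_past, head_mask, use_cache, output_attentions
    )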
@@ -1,9 +1,16 @@
+import math
+import warnings
 from typing import List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_attn_mask_utils import (
+    AttentionMaskConverter,
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
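The new imports above come from transformers.modeling_attn_mask_utils, which centralizes the causal-mask construction that older model classes built with private self._prepare_attn_mask helpers. A rough usage sketch follows, assuming a transformers version (roughly 4.35 or later) that ships these helpers; the tensor values are illustrative only.

# Rough usage sketch (assumes a transformers release providing modeling_attn_mask_utils);
# illustrates the mask expansion the new imports are used for, not ColossalAI code.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, hidden_size = 2, 5, 8
inputs_embeds = torch.zeros(batch_size, seq_length, hidden_size, dtype=torch.float16)
# 2D padding mask: 1 = attend, 0 = padding
attention_mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])

# Expands the 2D padding mask into a 4D additive causal mask of shape
# [batch_size, 1, seq_length, seq_length], with dtype-min at masked positions.
causal_mask = _prepare_4d_causal_attention_mask(
    attention_mask,
    input_shape=(batch_size, seq_length),
    inputs_embeds=inputs_embeds,
    past_key_values_length=0,
)
print(causal_mask.shape)  # torch.Size([2, 1, 5, 5])

# The Bloom forwards above additionally cast the result to bool, since Bloom's
# attention path expects a boolean mask rather than an additive float mask.
bool_mask = causal_mask.bool()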
@@ -99,11 +106,17 @@ def get_tp_falcon_decoder_layer_forward():
         hidden_states: torch.Tensor,
         alibi: Optional[torch.Tensor],
         attention_mask: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        **kwargs,
     ):
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
         residual = hidden_states

         if self.config.new_decoder_architecture:
@@ -117,10 +130,12 @@ def get_tp_falcon_decoder_layer_forward():
             attention_layernorm_out,
             layer_past=layer_past,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             alibi=alibi,
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
+            **kwargs,
         )

         attention_output = attn_outputs[0]
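The **kwargs plumbing added above keeps the tensor-parallel replacement forward compatible with callers that still pass deprecated arguments such as padding_mask. A small illustrative sketch (not ColossalAI code) of why a wrapper forward accepts and inspects **kwargs instead of a fixed argument list:

# Illustrative sketch, not ColossalAI code: why the patched forward grows **kwargs.
# Upstream transformers may still forward deprecated arguments (e.g. `padding_mask`)
# into decoder layers; a replacement forward with a fixed signature would raise a
# TypeError instead of emitting the intended deprecation warning.
import warnings


def patched_forward(hidden_states, attention_mask=None, **kwargs):
    if "padding_mask" in kwargs:
        warnings.warn("Passing `padding_mask` is deprecated; use `attention_mask` instead.")
    # ... pass **kwargs on to the inner attention so both old and new callers work
    return hidden_states


patched_forward("hidden", attention_mask=None, padding_mask=None)  # warns instead of failing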
@ -154,87 +169,6 @@ def get_tp_falcon_decoder_layer_forward():
|
|||
return forward
|
||||
|
||||
|
||||
def get_falcon_flash_attention_forward():
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention as me_attention
|
||||
except:
|
||||
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
|
||||
from transformers.models.falcon.modeling_falcon import FalconAttention
|
||||
|
||||
def forward(
|
||||
self: FalconAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: Optional[torch.Tensor],
|
||||
attention_mask: torch.Tensor,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
||||
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
|
||||
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
||||
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
||||
|
||||
batch_size, query_length, _, _ = query_layer.shape
|
||||
|
||||
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
|
||||
key_layer = key_layer.transpose(1, 2).reshape(
|
||||
batch_size * num_kv_heads,
|
||||
query_length,
|
||||
self.head_dim,
|
||||
)
|
||||
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
|
||||
|
||||
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
|
||||
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
# concatenate along seq_length dimension:
|
||||
# - key: [batch_size * self.num_heads, kv_length, head_dim]
|
||||
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
||||
key_layer = torch.cat((past_key, key_layer), dim=1)
|
||||
value_layer = torch.cat((past_value, value_layer), dim=1)
|
||||
|
||||
_, kv_length, _ = key_layer.shape
|
||||
if use_cache:
|
||||
present = (key_layer, value_layer)
|
||||
else:
|
||||
present = None
|
||||
|
||||
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
|
||||
|
||||
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).contiguous()
|
||||
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous()
|
||||
value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
if alibi is not None:
|
||||
attention_mask_float = (
|
||||
attention_mask_float + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
|
||||
)
|
||||
|
||||
batch_size, src_len = query_layer_.size()[0], query_layer_.size()[1]
|
||||
tgt_len = key_layer_.size()[1]
|
||||
attention_mask_float = attention_mask_float.expand(batch_size, self.num_heads, src_len, tgt_len).contiguous()
|
||||
context_layer = me_attention(
|
||||
query_layer_,
|
||||
key_layer_,
|
||||
value_layer_,
|
||||
attn_bias=attention_mask_float,
|
||||
scale=self.inv_norm_factor,
|
||||
p=self.attention_dropout.p,
|
||||
)
|
||||
batch_size, seq_length, _, _ = context_layer.shape
|
||||
context_layer = context_layer.reshape(batch_size, seq_length, -1)
|
||||
|
||||
output_tensor = self.dense(context_layer)
|
||||
|
||||
return output_tensor, present
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
class FalconPipelineForwards:
|
||||
"""
|
||||
This class serves as a micro library for falcon pipeline forwards.
|
||||
|
@ -246,6 +180,7 @@ class FalconPipelineForwards:
|
|||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
|
@ -274,17 +209,6 @@ class FalconPipelineForwards:
|
|||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if past_key_values is None:
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_key_values = self._convert_to_rw_cache(past_key_values)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape batch_size x num_heads x N x N
|
||||
# head_mask has shape n_layer x batch x num_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
# case: First stage of training
|
||||
if stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
|
@ -295,16 +219,22 @@ class FalconPipelineForwards:
|
|||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
|
||||
if past_key_values is None:
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
@ -312,22 +242,80 @@ class FalconPipelineForwards:
|
|||
# Compute alibi tensor: check build_alibi_tensor documentation
|
||||
past_key_values_length = 0
|
||||
if past_key_values[0] is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
|
||||
else:
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
past_key_values_length = past_key_values[0][0].shape[-2]
|
||||
|
||||
if self.use_alibi:
|
||||
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
|
||||
mask = (
|
||||
torch.ones(
|
||||
(batch_size, seq_length + past_key_values_length), device=inputs_embeds.device, dtype=torch.long
|
||||
)
|
||||
if attention_mask is None
|
||||
else attention_mask
|
||||
)
|
||||
alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
|
||||
else:
|
||||
alibi = None
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
|
||||
causal_mask = self._prepare_attn_mask(
|
||||
attention_mask,
|
||||
input_shape=(batch_size, seq_length),
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._use_sdpa and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
if alibi is None:
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
elif head_mask is None:
|
||||
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
|
||||
|
||||
attention_mask_2d = attention_mask
|
||||
# We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
# We take care to integrate alibi bias in the attention_mask here.
|
||||
if attention_mask_2d is None:
|
||||
attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
|
||||
else:
|
||||
attention_mask = torch.masked_fill(
|
||||
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
|
||||
attention_mask < -1,
|
||||
torch.finfo(alibi.dtype).min,
|
||||
)
|
||||
|
||||
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
|
||||
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
if seq_length > 1:
|
||||
attention_mask = AttentionMaskConverter._unmask_unattended(
|
||||
attention_mask, attention_mask_2d, unmasked_value=0.0
|
||||
)
|
||||
else:
|
||||
# PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape batch_size x num_heads x N x N
|
||||
# head_mask has shape n_layer x batch x num_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
for i, (block, layer_past) in enumerate(
|
||||
|
@ -337,31 +325,23 @@ class FalconPipelineForwards:
|
|||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
outputs = self._gradient_checkpointing_func(
|
||||
block.__call__,
|
||||
hidden_states,
|
||||
alibi,
|
||||
causal_mask,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
head_mask[i],
|
||||
layer_past,
|
||||
use_cache,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask[i],
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
|
@ -382,9 +362,6 @@ class FalconPipelineForwards:
|
|||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if presents is not None:
|
||||
presents = self._convert_cache_to_standard_format(presents, batch_size)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
|
|
|
@@ -177,11 +177,9 @@ class GPT2PipelineForwards:
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)

         if stage_manager.is_first_stage():
-            if position_ids is not None:
-                position_ids = position_ids.view(-1, input_shape[-1])
-            else:
+            if position_ids is None:
                 position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
-                position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+                position_ids = position_ids.unsqueeze(0)

             if inputs_embeds is None:
                 inputs_embeds = self.wte(input_ids)
@@ -239,22 +237,16 @@ class GPT2PipelineForwards:
             all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache, output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     None,
                     attention_mask,
                     head_mask[i],
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
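Like the Bloom hunks, the GPT-2 forward above now calls self._gradient_checkpointing_func, which only exists after gradient checkpointing has been enabled on the model. A hedged sketch of how that function is normally installed in recent transformers releases (API names per the 4.35-era interface; verify against the version pinned in requirements.txt):

# Hedged sketch of how `_gradient_checkpointing_func` is normally installed
# (transformers >= 4.35-era API; verify against the pinned version in requirements.txt).
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Attaches `model._gradient_checkpointing_func` (a thin wrapper around
# torch.utils.checkpoint.checkpoint) and enables checkpointing on supporting
# submodules. use_reentrant=False is the variant that tolerates running several
# (partial) backward passes, which this commit series also touches.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.train()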
@@ -148,11 +148,9 @@ class GPTJPipelineForwards:
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)

         # position id to be assigned not just for the first stage for attn input
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, seq_length)
-        else:
+        if position_ids is None:
             position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)
         if stage_manager.is_first_stage():
             if inputs_embeds is None:
                 inputs_embeds = self.wte(input_ids)
@@ -201,21 +199,15 @@ class GPTJPipelineForwards:
             all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache, output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     None,
                     attention_mask,
                     position_ids,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
@@ -627,7 +619,9 @@ def get_gptj_flash_attention_forward():
             value = torch.cat((past_value, value), dim=-2)

         if use_cache is True:
-            present = (key, value)
+            # Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
+            # Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
+            present = (key.to(hidden_states.dtype), value)
         else:
             present = None

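The GPT-J hunk above casts the cached key to hidden_states.dtype because, as the in-code comment notes, the reference implementation keeps keys in float32 through RoPE. A tiny illustration (not project code) of the dtype mismatch the cast avoids:

# Tiny illustration (not project code) of the dtype mismatch the cast avoids:
# the reference GPT-J implementation keeps keys in float32 through RoPE, so the
# fresh `key` can be float32 while the rest of the activations are half precision.
import torch

hidden_states = torch.randn(1, 4, 8, dtype=torch.float16)
key = torch.randn(1, 4, 8, dtype=torch.float32)  # e.g. produced by an fp32 RoPE path
value = torch.randn(1, 4, 8, dtype=torch.float16)

present = (key.to(hidden_states.dtype), value)  # cached tensors stay uniformly fp16
assert present[0].dtype == present[1].dtype == torch.float16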
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -16,6 +17,8 @@ from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
     LlamaForSequenceClassification,
     LlamaModel,
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
     apply_rotary_pos_emb,
     repeat_kv,
 )
@@ -31,13 +34,6 @@ from colossalai.shardformer.shard import ShardConfig

 from ..layer import ColoAttention, cross_entropy_1d

-try:
-    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
-
-    LATEST_VERSION = True
-except ImportError:
-    LATEST_VERSION = False
-

 class LlamaPipelineForwards:
     """
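The hunk above removes the LATEST_VERSION feature-detection block; per the commit message, the supported transformers range is expressed in the requirements instead ("remove the LATEST VERSION try", "add the transformers version requirements"). A hedged sketch of the fail-fast alternative to try/except feature detection; the version bound shown is illustrative, not the project's actual pin:

# Hedged sketch of the direction the commit message describes: instead of
# feature-detecting the helper at import time, fail fast with a clear message
# when the installed transformers is too old. The bound below is illustrative
# only, not the project's actual pin.
import transformers
from packaging import version

if version.parse(transformers.__version__) < version.parse("4.36.0"):
    raise ImportError("a transformers release providing modeling_attn_mask_utils is required")

from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask  # noqa: E402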
@@ -75,13 +71,13 @@ class LlamaPipelineForwards:
         # retrieve input_ids and inputs_embeds
         if stage_manager.is_first_stage():
             if input_ids is not None and inputs_embeds is not None:
-                raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+                raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
             elif input_ids is not None:
-                batch_size, seq_length = input_ids.shape
+                batch_size, seq_length = input_ids.shape[:2]
             elif inputs_embeds is not None:
-                batch_size, seq_length, _ = inputs_embeds.shape
+                batch_size, seq_length, _ = inputs_embeds.shape[:2]
             else:
-                raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+                raise ValueError("You have to specify either input_ids or inputs_embeds")
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             if inputs_embeds is None:
                 inputs_embeds = self.embed_tokens(input_ids)
@ -111,11 +107,12 @@ class LlamaPipelineForwards:
|
|||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
|
@ -123,20 +120,32 @@ class LlamaPipelineForwards:
|
|||
# in this case, attention_mask is a dict rather than a tensor
|
||||
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
|
||||
mask_shape,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
|
||||
)
|
||||
if LATEST_VERSION:
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._use_sdpa and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
else:
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
hidden_states,
|
||||
past_key_values_length,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
@ -149,7 +158,7 @@ class LlamaPipelineForwards:
|
|||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
next_decoder_cache = None
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
num_ckpt_layers = 0
|
||||
|
@ -160,7 +169,7 @@ class LlamaPipelineForwards:
|
|||
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
|
||||
stage=stage_manager.stage,
|
||||
num_layers=end_idx - start_idx,
|
||||
model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0,
|
||||
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
|
||||
)
|
||||
assert num_ckpt_layers <= end_idx - start_idx
|
||||
|
||||
|
@ -168,30 +177,22 @@ class LlamaPipelineForwards:
|
|||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if idx - start_idx < num_ckpt_layers:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
@ -199,7 +200,7 @@ class LlamaPipelineForwards:
|
|||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
|
@ -212,7 +213,16 @@ class LlamaPipelineForwards:
|
|||
next_cache = next_decoder_cache if use_cache else None
|
||||
if stage_manager.is_last_stage():
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
|
@ -458,23 +468,25 @@ class LlamaPipelineForwards:
|
|||
def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
|
||||
|
||||
llama_version = 2
|
||||
try:
|
||||
from transformers.models.llama.modeling_llama import repeat_kv
|
||||
except:
|
||||
warnings.warn("using llamav1, llamav1 hasn't repeat_kv function")
|
||||
llama_version = 1
|
||||
|
||||
def forward(
|
||||
self: LlamaAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[dict] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if sp_mode in ["split_gather", "ring"]:
|
||||
|
@ -498,21 +510,23 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
|
|||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
if llama_version == 2:
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
|
||||
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
|
||||
|
@ -573,7 +587,10 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
|
@ -587,7 +604,11 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|||
# in this case, attention_mask is a dict rather than a tensor
|
||||
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
|
||||
mask_shape,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
@ -918,7 +939,10 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
|
|||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
|
@ -934,10 +958,12 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
|
|||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
|
|
|
@ -1,70 +1,606 @@
|
|||
from typing import Optional, Tuple
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import MistralForCausalLM, MistralModel
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import ColoAttention
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_mistral_flash_attention_forward():
|
||||
class MistralForwards:
|
||||
@staticmethod
|
||||
def mistral_model_forward(
|
||||
self: MistralModel,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
if use_cache:
|
||||
logger.warning_once("use_cache=True is not supported for Mistral models at the moment.")
|
||||
use_cache = False
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
past_key_values_length = 0
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
|
||||
if shard_config.enable_flash_attention:
|
||||
# in this case, attention_mask is a dict rather than a tensor
|
||||
mask_shape = (batch_size, 1, seq_length, seq_length)
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
hidden_states,
|
||||
past_key_values_length,
|
||||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
num_ckpt_layers = 0
|
||||
if self.gradient_checkpointing and self.training:
|
||||
num_ckpt_layers = end_idx - start_idx
|
||||
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
|
||||
if shard_config.gradient_checkpoint_config is not None:
|
||||
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
|
||||
stage=stage_manager.stage,
|
||||
num_layers=end_idx - start_idx,
|
||||
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
|
||||
)
|
||||
assert num_ckpt_layers <= end_idx - start_idx
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if idx - start_idx < num_ckpt_layers:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = None
|
||||
if stage_manager.is_last_stage():
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
else:
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def mistral_for_causal_lm_forward(
|
||||
self: MistralForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, MistralForCausalLM
|
||||
|
||||
>>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = MistralForwards.mistral_model_forward(
|
||||
self.model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
|
||||
past_key_values = None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
else:
|
||||
hidden_states = outputs.get("hidden_states")
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def mistral_for_sequence_classification_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = MistralForwards.mistral_model_forward(
|
||||
self.model,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size = input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
else:
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = transformer_outputs[0]
|
||||
logits = self.score(hidden_states)
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
|
||||
logits.device
|
||||
)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
else:
|
||||
hidden_states = transformer_outputs.get("hidden_states")
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
return SequenceClassifierOutputWithPast(
|
||||
loss=loss,
|
||||
logits=pooled_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig):
    logger = logging.get_logger(__name__)
    assert shard_config.enable_flash_attention, "Flash Attention is not enabled."

    def forward(
        self: MistralModel,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        past_key_values_length = 0

        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if attention_mask is not None and self._use_flash_attention_2 and use_cache:
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right'"
                    " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                )
        if shard_config.enable_flash_attention:
            # in this case, attention_mask is a dict rather than a tensor
            mask_shape = (batch_size, 1, seq_length, seq_length)
            attention_mask = ColoAttention.prepare_attn_kwargs(
                mask_shape,
                inputs_embeds.dtype,
                inputs_embeds.device,
                q_padding_mask=attention_mask,
                is_causal=True,
            )
        else:
            if self._use_flash_attention_2:
                # 2d mask is passed through the layers
                attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
            else:
                # 4d mask is passed through the layers
                attention_mask = _prepare_4d_causal_attention_mask(
                    attention_mask,
                    (batch_size, seq_length),
                    inputs_embeds,
                    past_key_values_length,
                    sliding_window=self.config.sliding_window,
                )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    return forward


def get_mistral_flash_attention_forward(shard_config: ShardConfig):
    from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv

    def forward(
        self: MistralAttention,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
        attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

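# Illustrative aside (not part of the diff): shardformer attaches closures produced by
# factories such as get_mistral_model_forward_for_flash_attn and
# get_mistral_flash_attention_forward through its policy method-replacement mechanism.
# The toy below only demonstrates the underlying idea of rebinding a forward on one
# instance; DummyModel and get_custom_forward are hypothetical names.
import types

class DummyModel:
    def forward(self, x):
        return x

def get_custom_forward(scale):
    def forward(self, x):
        return x * scale
    return forward

m = DummyModel()
m.forward = types.MethodType(get_custom_forward(2), m)  # rebinds forward on this instance only
print(m.forward(3))  # 6
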
@@ -3,6 +3,7 @@ from typing import List, Optional, Tuple, Union
|
|||
|
||||
import torch
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
|
@@ -42,7 +43,7 @@ def _get_attention_mask(
|
|||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
attention_mask = self.decoder._prepare_decoder_attention_mask(
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
hidden_states,
|
||||
|
@@ -57,6 +58,20 @@ class OPTPipelineForwards:
|
|||
under pipeline setting.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
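# Illustrative aside (not part of the diff): what _expand_mask above produces for a toy
# 2d padding mask. Masked positions end up at the dtype's minimum so they vanish after
# softmax. The sample values are assumptions for illustration.
import torch

mask = torch.tensor([[1, 1, 0]])  # [bsz, src_len], 0 marks padding
expanded = mask[:, None, None, :].expand(1, 1, 3, 3).to(torch.float32)
inverted = 1.0 - expanded
out = inverted.masked_fill(inverted.to(torch.bool), torch.finfo(torch.float32).min)
print(out[0, 0, 0])  # tensor([0.0000e+00, 0.0000e+00, -3.4028e+38])
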
|
||||
|
||||
@staticmethod
|
||||
def opt_model_forward(
|
||||
self: OPTModel,
|
||||
|
@@ -112,7 +127,7 @@ class OPTPipelineForwards:
|
|||
inputs_embeds = decoder.project_in(inputs_embeds)
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
inputs_embeds.dtype
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
if hidden_states is None:
|
||||
raise ValueError("hidden_states shouldn't be None for intermediate stages.")
|
||||
|
@@ -125,12 +140,25 @@ class OPTPipelineForwards:
|
|||
# required mask seq length can be calculated via length of past
|
||||
mask_seq_length = past_key_values_length + seq_length
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
|
||||
elif attention_mask.shape[1] != mask_seq_length:
|
||||
raise ValueError(
|
||||
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
|
||||
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
|
||||
if self.decoder._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
attention_mask = (
|
||||
torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
|
||||
if attention_mask is None
|
||||
else attention_mask
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
|
||||
elif attention_mask.shape[1] != mask_seq_length:
|
||||
raise ValueError(
|
||||
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
|
||||
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
|
||||
)
|
||||
causal_attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, hidden_states, past_key_values_length
|
||||
)
|
||||
|
||||
if stage_manager.is_first_stage():
|
||||
|
@@ -205,20 +233,14 @@ class OPTPipelineForwards:
|
|||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if decoder.gradient_checkpointing and decoder.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
None,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
|
|
|
@@ -3,7 +3,6 @@ from typing import Dict, List, Optional, Tuple, Union
|
|||
|
||||
import torch
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
|
@@ -118,16 +117,13 @@ class T5PipelineForwards:
|
|||
# required mask seq length can be calculated via length of past
|
||||
mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
|
||||
if in_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
|
||||
encoder_seq_length = encoder_hidden_states.shape[1]
|
||||
encoder_attention_mask = torch.ones(batch_size, encoder_seq_length, device=device, dtype=torch.long)
|
||||
|
||||
# initialize past_key_values with `None` if past does not exist
|
||||
if past_key_values is None:
|
||||
past_key_values = [None] * len(self.block)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
@@ -138,7 +134,7 @@ class T5PipelineForwards:
|
|||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
@@ -162,15 +158,8 @@ class T5PipelineForwards:
|
|||
torch.cuda.set_device(hidden_states.device)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return tuple(module(*inputs, use_cache, output_attentions))
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = checkpoint(
|
||||
create_custom_forward(layer_module),
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.forward,
|
||||
hidden_states,
|
||||
extended_attention_mask,
|
||||
position_bias,
|
||||
|
@@ -180,6 +169,8 @@ class T5PipelineForwards:
|
|||
layer_head_mask,
|
||||
cross_attn_layer_head_mask,
|
||||
None, # past_key_value is always None with gradient checkpointing
|
||||
use_cache,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
|
|
|
@@ -14,6 +14,8 @@ def _encoder_forward(
|
|||
end_idx: int,
|
||||
hidden_states: torch.Tensor,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
stage_manager: PipelineStageManager = None,
|
||||
) -> Union[tuple, BaseModelOutput]:
|
||||
|
@@ -23,20 +25,14 @@ def _encoder_forward(
|
|||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
if encoder.gradient_checkpointing and encoder.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs, False)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(layer_module),
|
||||
layer_outputs = encoder._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(hidden_states, layer_head_mask, False)
|
||||
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if not stage_manager.is_last_stage():
|
||||
|
@@ -114,6 +110,8 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index:
|
|||
end_idx=stage_index[1],
|
||||
hidden_states=hidden_states,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
)
|
||||
|
|
|
@@ -5,6 +5,10 @@ from typing import List, Optional, Tuple, Union
|
|||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
|
@@ -35,6 +39,8 @@ def _get_attention_mask(
|
|||
hidden_states: torch.Tensor,
|
||||
past_key_values_length: int,
|
||||
attention_mask: Optional[torch.FloatTensor],
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
batch_size, seq_length = hidden_states.shape[:2]
|
||||
mask_seq_length = past_key_values_length + seq_length
|
||||
|
@@ -47,12 +53,20 @@ def _get_attention_mask(
|
|||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
hidden_states,
|
||||
past_key_values_length,
|
||||
)
|
||||
input_shape = (batch_size, seq_length)
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._use_sdpa and head_mask is None and not output_attentions:
|
||||
# output_attentions=True & head_mask can not be supported when using SDPA.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask, input_shape, hidden_states, past_key_values_length
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, hidden_states, past_key_values_length
|
||||
)
|
||||
return attention_mask
|
||||
|
||||
|
||||
|
@@ -539,18 +553,12 @@ class WhisperPipelineForwards:
|
|||
layer_outputs = (None, None)
|
||||
else:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(encoder_layer),
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
None,
|
||||
(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
|
@@ -702,20 +710,16 @@ class WhisperPipelineForwards:
|
|||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
attention_mask = _get_attention_mask(
|
||||
self, shard_config, inputs_embeds, past_key_values_length, attention_mask
|
||||
)
|
||||
|
||||
# embed positions
|
||||
if input_ids is not None:
|
||||
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
|
||||
else:
|
||||
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
|
||||
|
||||
attention_mask = _get_attention_mask(
|
||||
self,
|
||||
shard_config,
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds + positions
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
|
@@ -732,7 +736,6 @@ class WhisperPipelineForwards:
|
|||
"hidden_states shouldn't be None for stages other than the first stage of encoder/decoder."
|
||||
)
|
||||
input_shape = hidden_states.size()[:-1]
|
||||
|
||||
attention_mask = _get_attention_mask(
|
||||
self,
|
||||
shard_config,
|
||||
|
@@ -756,16 +759,8 @@ class WhisperPipelineForwards:
|
|||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, use_cache)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
|
@@ -773,6 +768,8 @@ class WhisperPipelineForwards:
|
|||
head_mask[idx] if head_mask is not None else None,
|
||||
(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
|
||||
None, # past_key_value
|
||||
output_attentions,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
|
|
|
@@ -24,12 +24,6 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDe
|
|||
class BloomPolicy(Policy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
import transformers
|
||||
from packaging.version import Version
|
||||
|
||||
assert Version(transformers.__version__) <= Version(
|
||||
"4.33.0"
|
||||
), "The Bloom model should run on a transformers version not greater than 4.33.0."
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
|
|
@@ -7,12 +7,7 @@ from torch.nn import Module
|
|||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
|
||||
from ..modeling.falcon import (
|
||||
FalconPipelineForwards,
|
||||
build_falcon_alibi_tensor_fn,
|
||||
get_falcon_flash_attention_forward,
|
||||
get_tp_falcon_decoder_layer_forward,
|
||||
)
|
||||
from ..modeling.falcon import FalconPipelineForwards, build_falcon_alibi_tensor_fn, get_tp_falcon_decoder_layer_forward
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["FalconPolicy"]
|
||||
|
@ -21,12 +16,6 @@ __all__ = ["FalconPolicy"]
|
|||
class FalconPolicy(Policy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
import transformers
|
||||
from packaging.version import Version
|
||||
|
||||
assert Version(transformers.__version__) <= Version(
|
||||
"4.33.0"
|
||||
), "The Falcon model should run on a transformers version not greater than 4.33.0."
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
@@ -36,7 +25,7 @@ class FalconPolicy(Policy):
|
|||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.falcon.modeling_falcon import FalconAttention, FalconDecoderLayer, FalconModel
|
||||
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
|
||||
|
||||
if not self.model.config.new_decoder_architecture and self.model.config.multi_query:
|
||||
warnings.warn(
|
||||
|
@@ -147,11 +136,8 @@ class FalconPolicy(Policy):
|
|||
)
|
||||
|
||||
if self.shard_config.enable_flash_attention:
|
||||
self.append_or_create_method_replacement(
|
||||
description={"forward": get_falcon_flash_attention_forward()},
|
||||
policy=policy,
|
||||
target_key=FalconAttention,
|
||||
)
|
||||
warnings.warn("Falcon doesn't support flash attention now, fallback to transformers attention.")
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
|
|
|
@@ -35,13 +35,20 @@ class GPT2Policy(Policy):
|
|||
Reshape the Embedding layer to make the embedding dimension divisible by world_size
|
||||
"""
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
self.origin_attn_implement = self.model.config._attn_implementation
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
|
||||
|
||||
ATTN_IMPLEMENTATION = {
|
||||
"eager": GPT2Attention,
|
||||
}
|
||||
|
||||
policy = {}
|
||||
|
||||
attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = col_nn.VocabParallelEmbedding1D
|
||||
|
@@ -186,7 +193,7 @@ class GPT2Policy(Policy):
|
|||
"forward": get_gpt2_flash_attention_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=GPT2Attention,
|
||||
target_key=attn_cls,
|
||||
)
|
||||
if not self.shard_config.pipeline_stage_manager:
|
||||
policy[GPT2Model].method_replacement = {
|
||||
|
|
|
@@ -30,13 +30,20 @@ class GPTJPolicy(Policy):
|
|||
|
||||
def preprocess(self):
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
self.origin_attn_implement = self.model.config._attn_implementation
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel
|
||||
|
||||
ATTN_IMPLEMENTATION = {
|
||||
"eager": GPTJAttention,
|
||||
}
|
||||
|
||||
policy = {}
|
||||
|
||||
attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = col_nn.VocabParallelEmbedding1D
|
||||
|
@@ -160,7 +167,7 @@ class GPTJPolicy(Policy):
|
|||
"forward": get_gptj_flash_attention_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=GPTJAttention,
|
||||
target_key=attn_cls,
|
||||
)
|
||||
if not self.shard_config.pipeline_stage_manager:
|
||||
self.append_or_create_method_replacement(
|
||||
|
|
|
@@ -36,13 +36,26 @@ class LlamaPolicy(Policy):
|
|||
|
||||
def preprocess(self):
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
self.origin_attn_implement = self.model.config._attn_implementation
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaDecoderLayer,
|
||||
LlamaFlashAttention2,
|
||||
LlamaModel,
|
||||
LlamaSdpaAttention,
|
||||
)
|
||||
|
||||
ATTN_IMPLEMENTATION = {
|
||||
"eager": LlamaAttention,
|
||||
"flash_attention_2": LlamaFlashAttention2,
|
||||
"sdpa": LlamaSdpaAttention,
|
||||
}
|
||||
policy = {}
|
||||
|
||||
attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = VocabParallelEmbedding1D
|
||||
|
@@ -93,7 +106,7 @@ class LlamaPolicy(Policy):
|
|||
"forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=LlamaAttention,
|
||||
target_key=attn_cls,
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
decoder_attribute_replacement = {
|
||||
|
@@ -102,7 +115,7 @@ class LlamaPolicy(Policy):
|
|||
if getattr(self.model.config, "num_key_value_heads", False):
|
||||
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
|
||||
|
||||
policy[LlamaAttention] = ModulePolicyDescription(
|
||||
policy[attn_cls] = ModulePolicyDescription(
|
||||
attribute_replacement=decoder_attribute_replacement,
|
||||
)
|
||||
self.append_or_create_method_replacement(
|
||||
|
@@ -110,7 +123,7 @@ class LlamaPolicy(Policy):
|
|||
"forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=LlamaAttention,
|
||||
target_key=attn_cls,
|
||||
)
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
|
@@ -221,7 +234,7 @@ class LlamaPolicy(Policy):
|
|||
"forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=LlamaAttention,
|
||||
target_key=attn_cls,
|
||||
)
|
||||
if self.pipeline_stage_manager is None:
|
||||
# replace llama model forward method
|
||||
|
|
|
@@ -1,7 +1,10 @@
|
|||
import warnings
|
||||
from typing import Dict, Union
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
|
||||
from colossalai.shardformer.layer import (
|
||||
FusedRMSNorm,
|
||||
|
@@ -13,7 +16,11 @@ from colossalai.shardformer.layer import (
|
|||
VocabParallelLMHead1D,
|
||||
)
|
||||
|
||||
from ..modeling.mistral import get_mistral_flash_attention_forward
|
||||
from ..modeling.mistral import (
|
||||
MistralForwards,
|
||||
get_mistral_flash_attention_forward,
|
||||
get_mistral_model_forward_for_flash_attn,
|
||||
)
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["MistralPolicy", "MistralModelPolicy", "MistralForCausalLMPolicy", "MistralForSequenceClassificationPolicy"]
|
||||
|
@@ -25,13 +32,26 @@ class MistralPolicy(Policy):
|
|||
|
||||
def preprocess(self):
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
self.origin_attn_implement = self.model.config._attn_implementation
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
from transformers.models.mistral.modeling_mistral import MistralAttention, MistralDecoderLayer, MistralModel
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralAttention,
|
||||
MistralDecoderLayer,
|
||||
MistralFlashAttention2,
|
||||
MistralModel,
|
||||
)
|
||||
|
||||
ATTN_IMPLEMENTATION = {
|
||||
"eager": MistralAttention,
|
||||
"flash_attention_2": MistralFlashAttention2,
|
||||
}
|
||||
|
||||
policy = {}
|
||||
|
||||
attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = VocabParallelEmbedding1D
|
||||
|
@@ -127,27 +147,112 @@ class MistralPolicy(Policy):
|
|||
if self.shard_config.enable_flash_attention:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_mistral_flash_attention_forward(),
|
||||
"forward": get_mistral_flash_attention_forward(self.shard_config),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=MistralAttention,
|
||||
target_key=attn_cls,
|
||||
)
|
||||
if self.pipeline_stage_manager is None:
|
||||
# replace mistral model forward method
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_mistral_model_forward_for_flash_attn(self.shard_config),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=MistralModel,
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||
to customized forward method, and add this changing to policy."""
|
||||
if self.pipeline_stage_manager is None:
|
||||
return
|
||||
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if self.model.__class__.__name__ == "MistralModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
|
||||
if stage_manager.is_interleave:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {
|
||||
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
|
||||
}
|
||||
|
||||
else:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
|
||||
)
|
||||
}
|
||||
|
||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
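# Illustrative aside (not part of the diff): functools.partial is what lets one
# stage-agnostic forward be specialised per pipeline stage before it is registered
# as a method replacement. The names below are toy stand-ins, not the real
# MistralForwards signature.
from functools import partial

def staged_forward(x, *, stage_index, shard_config=None):
    return f"stage {stage_index}: {x}"

forward_for_stage_1 = partial(staged_forward, stage_index=(8, 16))
print(forward_for_stage_1("hidden states"))  # stage (8, 16): hidden states
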
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
assert self.pipeline_stage_manager is not None
|
||||
|
||||
if self.model.__class__.__name__ == "MistralModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
if stage_manager.is_interleave:
|
||||
assert stage_manager.num_model_chunks is not None
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(module.embed_tokens)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(module.norm)
|
||||
|
||||
else:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
return held_layers
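# Illustrative aside (not part of the diff): a rough sketch of splitting decoder layers
# evenly across pipeline stages, in the spirit of stage_manager.distribute_layers /
# get_stage_index used above. The even split below is an assumption for illustration,
# not the actual PipelineStageManager implementation.
def distribute_layers_evenly(num_layers: int, num_stages: int):
    base, remainder = divmod(num_layers, num_stages)
    sizes = [base + (1 if s < remainder else 0) for s in range(num_stages)]
    indices, start = [], 0
    for size in sizes:
        indices.append((start, start + size))
        start += size
    return indices

print(distribute_layers_evenly(32, 4))  # [(0, 8), (8, 16), (16, 24), (24, 32)]
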
|
||||
|
||||
|
||||
class MistralModelPolicy(MistralPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
if self.pipeline_stage_manager:
|
||||
warnings.warn("Mistral doesn't support pipeline parallelism now.")
|
||||
policy = super().module_policy()
|
||||
from transformers.models.mistral.modeling_mistral import MistralModel
|
||||
|
||||
return super().module_policy()
|
||||
if self.pipeline_stage_manager:
|
||||
self.set_pipeline_forward(
|
||||
model_cls=MistralModel, new_forward=MistralForwards.mistral_model_forward, policy=policy
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
held_layers = super().get_held_layers()
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
"""No shared params in mistral model"""
|
||||
return []
|
||||
|
||||
|
||||
class MistralForCausalLMPolicy(MistralPolicy):
|
||||
|
@ -155,8 +260,6 @@ class MistralForCausalLMPolicy(MistralPolicy):
|
|||
from transformers import MistralForCausalLM
|
||||
|
||||
policy = super().module_policy()
|
||||
if self.pipeline_stage_manager:
|
||||
warnings.warn("Mistral doesn't support pipeline parallelism now.")
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for casual lm
|
||||
|
@@ -189,8 +292,38 @@ class MistralForCausalLMPolicy(MistralPolicy):
|
|||
|
||||
policy.update(new_item)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=MistralForCausalLM, new_forward=MistralForwards.mistral_for_causal_lm_forward, policy=policy
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
mistral_model = self.model.model
|
||||
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||
if (
|
||||
id(mistral_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||
and self.pipeline_stage_manager.num_stages > 1
|
||||
):
|
||||
# tie weights
|
||||
return [
|
||||
{
|
||||
0: mistral_model.embed_tokens.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
}
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
class MistralForSequenceClassificationPolicy(MistralPolicy):
|
||||
def module_policy(self):
|
||||
|
@@ -209,9 +342,26 @@ class MistralForSequenceClassificationPolicy(MistralPolicy):
|
|||
]
|
||||
)
|
||||
}
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
warnings.warn("Mistral doesn't support pipeline parallelism now.")
|
||||
|
||||
policy.update(new_item)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=MistralForSequenceClassification,
|
||||
new_forward=MistralForwards.mistral_for_sequence_classification_forward,
|
||||
policy=policy,
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.score)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
"""No shared params in llama for sequence classification model"""
|
||||
return []
|
||||
|
|
|
@@ -38,26 +38,27 @@ __all__ = [
|
|||
class OPTPolicy(Policy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
import transformers
|
||||
from packaging.version import Version
|
||||
|
||||
# TODO: remove this version check when transformers>=4.36.0
|
||||
assert Version(transformers.__version__) <= Version(
|
||||
"4.33.0"
|
||||
), "The OPT model should run on a transformers version not greater than 4.33.0."
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
self.origin_attn_implement = self.model.config._attn_implementation
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
|
||||
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer, OptFlashAttention2
|
||||
|
||||
ATTN_IMPLEMENTATION = {
|
||||
"eager": OPTAttention,
|
||||
"flash_attention_2": OptFlashAttention2,
|
||||
}
|
||||
|
||||
policy = {}
|
||||
|
||||
attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = VocabParallelEmbedding1D
|
||||
|
@@ -88,7 +89,7 @@ class OPTPolicy(Policy):
|
|||
]
|
||||
)
|
||||
|
||||
policy[OPTAttention] = ModulePolicyDescription(
|
||||
policy[attn_cls] = ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
"embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
"num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
|
@@ -158,7 +159,7 @@ class OPTPolicy(Policy):
|
|||
"forward": get_opt_flash_attention_forward(self.shard_config),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=OPTAttention,
|
||||
target_key=attn_cls,
|
||||
)
|
||||
if not self.shard_config.pipeline_stage_manager:
|
||||
self.append_or_create_method_replacement(
|
||||
|
|
|
@@ -1,6 +1,8 @@
|
|||
import warnings
|
||||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
|
||||
from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward
|
||||
from ..modeling.sam import forward_fn
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["SamPolicy", "SamModelPolicy"]
|
||||
|
@@ -15,7 +17,6 @@ class SamPolicy(Policy):
|
|||
|
||||
def module_policy(self):
|
||||
from transformers.models.sam.modeling_sam import (
|
||||
SamAttention,
|
||||
SamTwoWayAttentionBlock,
|
||||
SamTwoWayTransformer,
|
||||
SamVisionAttention,
|
||||
|
@@ -210,20 +211,21 @@ class SamPolicy(Policy):
|
|||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_sam_flash_attention_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=SamAttention,
|
||||
)
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_sam_vision_flash_attention_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=SamVisionAttention,
|
||||
)
|
||||
warnings.warn("Flash attention is not supported in SAM model. Fallback to normal attention.")
|
||||
# self.append_or_create_method_replacement(
|
||||
# description={
|
||||
# "forward": get_sam_flash_attention_forward(),
|
||||
# },
|
||||
# policy=policy,
|
||||
# target_key=SamAttention,
|
||||
# )
|
||||
# self.append_or_create_method_replacement(
|
||||
# description={
|
||||
# "forward": get_sam_vision_flash_attention_forward(),
|
||||
# },
|
||||
# policy=policy,
|
||||
# target_key=SamVisionAttention,
|
||||
# )
|
||||
|
||||
return policy
|
||||
|
||||
|
|
|
@@ -29,13 +29,6 @@ __all__ = [
|
|||
class WhisperPolicy(Policy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
import transformers
|
||||
from packaging.version import Version
|
||||
|
||||
# TODO: remove this version check when transformers>=4.36.0
|
||||
assert Version(transformers.__version__) <= Version(
|
||||
"4.33.0"
|
||||
), "The Whisper model should run on a transformers version not greater than 4.33.0."
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
@@ -55,6 +48,8 @@ class WhisperPolicy(Policy):
|
|||
WhisperDecoderLayer,
|
||||
WhisperEncoder,
|
||||
WhisperEncoderLayer,
|
||||
WhisperFlashAttention2,
|
||||
WhisperSdpaAttention,
|
||||
)
|
||||
|
||||
policy = {}
|
||||
|
@@ -249,6 +244,20 @@ class WhisperPolicy(Policy):
|
|||
policy=policy,
|
||||
target_key=WhisperAttention,
|
||||
)
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_whisper_flash_attention_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=WhisperFlashAttention2,
|
||||
)
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_whisper_flash_attention_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=WhisperSdpaAttention,
|
||||
)
|
||||
if not self.shard_config.pipeline_stage_manager:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
|
|
|
@@ -840,6 +840,7 @@ class GeminiDDP(ModelWrapper):
|
|||
for buffer in self.module.buffers():
|
||||
if isinstance(buffer, LazyTensor):
|
||||
buffer.materialize()
|
||||
for buffer in self.module.buffers():
|
||||
buffer.data = buffer.to(get_accelerator().get_current_device())
|
||||
if torch.is_floating_point(buffer):
|
||||
buffer.data = buffer.to(self.mixed_precision)
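# Illustrative aside (not part of the diff): the floating-point guard above means only
# float buffers are cast to the mixed-precision dtype, while integer buffers keep their
# dtype. A minimal standalone reproduction of that behaviour; Toy is a hypothetical module.
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.ones(4))                     # float buffer
        self.register_buffer("ids", torch.arange(4, dtype=torch.long))   # int buffer

m = Toy()
for buf in m.buffers():
    if torch.is_floating_point(buf):
        buf.data = buf.to(torch.float16)
print(m.scale.dtype, m.ids.dtype)  # torch.float16 torch.int64
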
|
||||
|
|
|
@@ -3,7 +3,6 @@ pytest
|
|||
coverage==7.2.3
|
||||
git+https://github.com/hpcaitech/pytest-testmon
|
||||
torchvision
|
||||
transformers==4.33.0
|
||||
timm
|
||||
titans
|
||||
torchaudio
|
||||
|
|
|
@@ -16,3 +16,4 @@ ray
|
|||
sentencepiece
|
||||
google
|
||||
protobuf
|
||||
transformers==4.36.2
|
||||
|
|
|
@@ -64,7 +64,6 @@ if HAS_LLAMA:
|
|||
intermediate_size=64,
|
||||
num_attention_heads=4,
|
||||
max_position_embeddings=128,
|
||||
num_labels=16,
|
||||
)
|
||||
|
||||
if hasattr(config, "pad_token_id"):
|
||||
|
|
|
@@ -52,6 +52,9 @@ config = MistralConfig(
|
|||
hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2, vocab_size=50258
|
||||
)
|
||||
|
||||
if hasattr(config, "pad_token_id"):
|
||||
config.pad_token_id = config.eos_token_id
|
||||
|
||||
model_zoo.register(
|
||||
name="transformers_mistral",
|
||||
model_fn=lambda: transformers.MistralModel(config),
|
||||
|
|
|
@@ -32,7 +32,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
model_fn, loss_fn, test_config
|
||||
)
|
||||
if enable_gradient_checkpointing:
|
||||
org_model.gradient_checkpointing_enable()
|
||||
# org_model.gradient_checkpointing_enable()
|
||||
sharded_model.unwrap().gradient_checkpointing_enable()
|
||||
|
||||
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
|
||||
|
|
|
@@ -91,7 +91,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
# check weights
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if test_config["precision"] == "fp32":
|
||||
atol, rtol = 1e-4, 1e-3
|
||||
atol, rtol = 2e-4, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
check_weight(
|
||||
|
@@ -114,6 +114,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
@parameterize(
|
||||
"test_config",
|
||||
[
|
||||
{
|
||||
"tp_size": 1,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 2,
|
||||
"enable_all_optimization": True,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 2,
|
||||
"enable_all_optimization": True,
|
||||
"use_lazy_init": True,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 4,
|
||||
"pp_size": 1,
|
||||
|
@@ -156,7 +174,6 @@ def check_mistral(rank, world_size, port):
|
|||
run_mistral_test()
|
||||
|
||||
|
||||
@pytest.mark.skip("This test should be run on a version of transformers not less than 4.35.2.")
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
|
|
|
@@ -116,7 +116,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
"num_microbatches": 2,
|
||||
"enable_metadata_cache": False,
|
||||
"enable_all_optimization": True,
|
||||
"use_lazy_init": True,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp32",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
|
|