Browse Source

merge model and attention forward

pull/5818/head
GuangyaoZhang 5 months ago
parent
commit
3c7302ad0e
  1. 242
      colossalai/shardformer/modeling/command.py
  2. 24
      colossalai/shardformer/policies/command.py

242
colossalai/shardformer/modeling/command.py

@ -1,5 +1,4 @@
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
@ -334,217 +333,6 @@ class CommandPipelineForwards:
return {"hidden_states": hidden_states}
def get_command_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
from transformers.models.cohere.modeling_cohere import CohereAttention, apply_rotary_pos_emb, repeat_kv
def forward(
self: CohereAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[dict] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
if sp_mode in ["split_gather", "ring"]:
q_len *= sp_size
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# sp: all-to-all comminucation when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
# sp: all-to-all comminucation when introducing sequence parallel
if sp_mode == "all_to_all":
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
return forward
def get_command_model_forward_for_flash_attn(shard_config: ShardConfig):
logger = logging.get_logger(__name__)
assert shard_config.enable_flash_attention, "Flash Attention is not enabled."
def forward(
self: CohereModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# embed positions
hidden_states = inputs_embeds
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (hidden_states.shape[0], 1, past_seen_tokens, past_seen_tokens)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
)
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return forward
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
from transformers import CohereForCausalLM
@ -647,7 +435,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
return forward
def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group, use_flash_attention):
from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
def forward(
@ -692,7 +480,12 @@ def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if use_flash_attention:
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
@ -710,23 +503,20 @@ def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
# sp: all-to-all comminucation when introducing sequence parallel
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
if not output_attentions or use_flash_attention:
attn_weights = None
return attn_output, attn_weights, past_key_value
return forward
def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group, use_flash_attention):
logger = logging.get_logger(__name__)
def forward(
@ -779,7 +569,17 @@ def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
if use_flash_attention:
hidden_states = inputs_embeds
mask_shape = (hidden_states.shape[0], 1, past_seen_tokens, past_seen_tokens)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
)
else:
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
if sp_mode in ["ring", "split_gather"]:

24
colossalai/shardformer/policies/command.py

@ -19,8 +19,6 @@ from colossalai.shardformer.layer import (
from ..modeling.command import (
CommandPipelineForwards,
get_command_flash_attention_forward,
get_command_model_forward_for_flash_attn,
get_command_seq_parallel_attention_forward,
get_command_seq_parallel_model_forward,
get_lm_forward_with_dist_cross_entropy,
@ -95,7 +93,10 @@ class CommandPolicy(Policy):
self.append_or_create_method_replacement(
description={
"forward": get_command_seq_parallel_model_forward(
sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group
sp_mode=sp_mode,
sp_size=sp_size,
sp_group=sp_group,
use_flash_attention=use_flash_attention,
),
},
policy=policy,
@ -103,7 +104,9 @@ class CommandPolicy(Policy):
)
self.append_or_create_method_replacement(
description={
"forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
"forward": get_command_seq_parallel_attention_forward(
sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention
),
},
policy=policy,
target_key=attn_cls,
@ -120,7 +123,9 @@ class CommandPolicy(Policy):
)
self.append_or_create_method_replacement(
description={
"forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
"forward": get_command_seq_parallel_attention_forward(
sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention
),
},
policy=policy,
target_key=attn_cls,
@ -131,6 +136,7 @@ class CommandPolicy(Policy):
sp_mode=sp_mode,
sp_size=sp_size,
sp_group=sp_group,
use_flash_attention=use_flash_attention,
),
},
policy=policy,
@ -234,7 +240,9 @@ class CommandPolicy(Policy):
if use_flash_attention:
self.append_or_create_method_replacement(
description={
"forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
"forward": get_command_seq_parallel_attention_forward(
sp_mode, sp_group, sp_size, use_flash_attention=use_flash_attention
),
},
policy=policy,
target_key=attn_cls,
@ -243,7 +251,9 @@ class CommandPolicy(Policy):
# replace Command model forward method
self.append_or_create_method_replacement(
description={
"forward": get_command_model_forward_for_flash_attn(self.shard_config),
"forward": get_command_seq_parallel_model_forward(
sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention
),
},
policy=policy,
target_key=CohereModel,

Loading…
Cancel
Save