change command

pull/5818/head
GuangyaoZhang 2024-06-14 03:04:56 +00:00
parent 0b81163bc0
commit f656d61778
9 changed files with 778 additions and 435 deletions


@ -4,7 +4,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm, CohereLayerNorm, FusedCohereLayerNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
@ -23,6 +23,8 @@ __all__ = [
"RMSNorm",
"FusedLayerNorm",
"FusedRMSNorm",
"CohereLayerNorm",
"FusedCohereLayerNorm",
"FusedLinear1D_Col",
"ParallelModule",
"PaddingEmbedding",


@ -4,6 +4,7 @@ import warnings
from abc import ABC, abstractmethod
import torch.nn as nn
from transformers.models.cohere.modeling_cohere import CohereLayerNorm
from colossalai.lazy import LazyInitContext
@ -249,6 +250,115 @@ class FusedLayerNorm(BaseLayerNorm):
return layernorm
class CohereLayerNorm(BaseLayerNorm):
r"""
This is a wrapper around the transformers.models.cohere.CohereLayerNorm. It is meant to be used only with the from_native_module interface.
"""
def __init__(self) -> None:
raise NotImplementedError(
"CohereLayerNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface to convert a transformers.models.cohere.CohereLayerNorm module to colossalai layer norm module."
)
@staticmethod
def from_native_module(module: CohereLayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
r"""
Convert a CohereLayerNorm module to a colossalai layer norm module,
and optionally mark parameters for gradient aggregation.
Args:
module (transformers.models.cohere.CohereLayerNorm): The CohereLayerNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
Returns:
nn.Module: The LayerNorm module.
Raises:
AssertionError: If the provided module is not an instance of CohereLayerNorm
"""
LazyInitContext.materialize(module)
if sp_partial_derived:
# Since gradients are computed using only a subset of the data,
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
return module
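
# A minimal usage sketch (not part of this commit) of the from_native_module interface
# described above. Assumes a transformers build that ships the Cohere model and that this
# commit's export of CohereLayerNorm from colossalai.shardformer.layer is installed;
# the hidden size and input shape are illustrative values.
import torch
from transformers.models.cohere.modeling_cohere import CohereLayerNorm as HFCohereLayerNorm
from colossalai.shardformer.layer import CohereLayerNorm

native = HFCohereLayerNorm(hidden_size=64, eps=1e-5)
# The wrapper is never instantiated directly (its __init__ raises); it only converts an
# existing module and, when sp_partial_derived=True, tags the weight for gradient aggregation.
converted = CohereLayerNorm.from_native_module(native, sp_partial_derived=False)

x = torch.randn(2, 8, 64)
assert torch.allclose(native(x), converted(x))  # conversion is in place: same module object
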
class FusedCohereLayerNorm(BaseLayerNorm):
r"""
This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
"""
def __init__(self) -> None:
raise NotImplementedError(
"FusedCohereLayerNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface convert a transformers.models.cohere.CohereLayerNorm module to FusedLayerNorm module provided by apex."
)
@staticmethod
def from_native_module(module: CohereLayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
r"""
Convert a CohereLayerNorm module to the FusedLayerNorm module provided by apex,
and optionally mark parameters for gradient aggregation.
Args:
module (transformers.models.cohere.CohereLayerNorm): The CohereLayerNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
Returns:
nn.Module: Union[FastLayerNorm, FusedLayerNorm].
Raises:
AssertionError: If the provided module is not an instance of transformers.models.cohere.CohereLayerNorm.
"""
LazyInitContext.materialize(module)
# get the attributes of the module
normalized_shape = module.weight.size(0)
eps = module.variance_epsilon
elementwise_affine = True
dtype = module.weight.dtype
device = module.weight.device
# pick the suitable layernorm implementation
use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE
if use_fast_ln:
if EnableFastLayerNorm:
ApexFusedLayerNorm = FastLayerNormWithHook
else:
# fall back to the normal fused layernorm if the fast layernorm kernel is not built
ApexFusedLayerNorm = FusedLayerNormWithHook
else:
try:
ApexFusedLayerNorm = FusedLayerNormWithHook
except NameError:
warnings.warn(
"Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
)
return module
layernorm = (
ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
)
layernorm.weight = module.weight
if sp_partial_derived:
# Since gradients are computed using only a subset of the data,
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
return layernorm
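
# A minimal numerical sketch (not part of this commit) of why the conversion above is sound:
# Cohere's layer norm is a standard LayerNorm with a learnable weight and no bias term, so a
# (fused) LayerNorm whose bias stays at zero reproduces it exactly. This is also why the code
# reads normalized_shape from module.weight.size(0) and eps from module.variance_epsilon.
# The shapes and epsilon below are illustrative; only torch is required.
import torch
import torch.nn.functional as F

hidden = 64
eps = 1e-5
x = torch.randn(2, 8, hidden)
weight = torch.rand(hidden)

# Cohere-style normalization: subtract the mean, divide by sqrt(biased variance + eps), scale.
mean = x.mean(-1, keepdim=True)
var = (x - mean).pow(2).mean(-1, keepdim=True)
cohere_style = weight * (x - mean) * torch.rsqrt(var + eps)

# Standard LayerNorm with the same weight and no bias.
layer_norm_style = F.layer_norm(x, (hidden,), weight=weight, bias=None, eps=eps)

torch.testing.assert_close(cohere_style, layer_norm_style, atol=1e-5, rtol=1e-5)
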
class FusedRMSNorm(BaseLayerNorm):
"""
This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.


@ -7,21 +7,16 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.cache_utils import Cache
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.models.llama.modeling_llama import (
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaModel,
apply_rotary_pos_emb,
from transformers.models.cohere.modeling_cohere import (
CohereForCausalLM,
CohereModel,
StaticCache,
repeat_kv,
)
from transformers.utils import logging
@ -37,15 +32,15 @@ from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, cross_entropy_1d
class LlamaPipelineForwards:
class CommandPipelineForwards:
"""
This class serves as a micro library for forward function substitution of Llama models
This class serves as a micro library for forward function substitution of Command models
under pipeline setting.
"""
@staticmethod
def llama_model_forward(
self: LlamaModel,
def command_model_forward(
self: CohereModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@ -55,6 +50,7 @@ class LlamaPipelineForwards:
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
@ -68,6 +64,12 @@ class LlamaPipelineForwards:
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with pipeline parallelism. Setting `use_cache=False`..."
)
use_cache = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
@ -89,8 +91,17 @@ class LlamaPipelineForwards:
batch_size, seq_length = input_shape
device = hidden_states.device
seq_length_with_past = seq_length
past_key_values_length = 0
past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)
seq_length_with_past = seq_length + past_seen_tokens
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
@ -103,18 +114,8 @@ class LlamaPipelineForwards:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0)
position_ids = cache_position.unsqueeze(0)
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
@ -129,28 +130,9 @@ class LlamaPipelineForwards:
is_causal=True,
)
else:
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
)
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
if self.gradient_checkpointing and self.training:
if self.gradient_checkpointing and self.training and use_cache:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@ -190,6 +172,7 @@ class LlamaPipelineForwards:
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
@ -199,6 +182,7 @@ class LlamaPipelineForwards:
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
@ -237,8 +221,8 @@ class LlamaPipelineForwards:
return {"hidden_states": hidden_states}
@staticmethod
def llama_for_causal_lm_forward(
self: LlamaForCausalLM,
def command_for_causal_lm_forward(
self: CohereForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@ -249,6 +233,7 @@ class LlamaPipelineForwards:
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
@ -266,9 +251,9 @@ class LlamaPipelineForwards:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> from transformers import AutoTokenizer, CohereForCausalLM
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> model = CohereForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
@ -295,7 +280,7 @@ class LlamaPipelineForwards:
output_hidden_states = False
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = LlamaPipelineForwards.llama_model_forward(
outputs = CommandPipelineForwards.command_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
@ -306,6 +291,7 @@ class LlamaPipelineForwards:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
@ -316,6 +302,8 @@ class LlamaPipelineForwards:
if stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits * self.logit_scale
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
@ -355,137 +343,20 @@ class LlamaPipelineForwards:
hidden_states = outputs.get("hidden_states")
return {"hidden_states": hidden_states}
@staticmethod
def llama_for_sequence_classification_forward(
self: LlamaForSequenceClassification,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
transformer_outputs = LlamaPipelineForwards.llama_model_forward(
self.model,
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
if input_ids is not None:
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
batch_size = inputs_embeds.shape[0]
else:
batch_size = hidden_states.shape[0]
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
else:
hidden_states = transformer_outputs.get("hidden_states")
return {"hidden_states": hidden_states}
def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
try:
from transformers.models.llama.modeling_llama import repeat_kv
except:
warnings.warn("using llamav1, llamav1 hasn't repeat_kv function")
def get_command_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
from transformers.models.cohere.modeling_cohere import CohereAttention, apply_rotary_pos_emb
from transformers.models.cohere.modeling_cohere import repeat_kv
def forward(
self: LlamaAttention,
self: CohereAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[dict] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
@ -520,13 +391,14 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
@ -547,12 +419,12 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
return forward
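
# A minimal sketch (not part of this commit) of the rotary-embedding calling convention the
# forward above switches to: the rotary module now takes position_ids directly and returns
# cos/sin already indexed per position, so apply_rotary_pos_emb no longer needs position_ids.
# The toy frequencies and shapes below are illustrative and do not reproduce Cohere's exact
# (interleaved) rotary layout.
import torch

def toy_rotary(position_ids: torch.Tensor, dim: int = 8, base: float = 10000.0):
    # Returns cos/sin of shape (batch, seq_len, dim), one row per requested position.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    freqs = position_ids[..., None].float() * inv_freq   # (batch, seq_len, dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)               # (batch, seq_len, dim)
    return emb.cos(), emb.sin()

position_ids = torch.arange(4).unsqueeze(0)               # (1, 4)
cos, sin = toy_rotary(position_ids)                       # per-position cos/sin
print(cos.shape)                                          # torch.Size([1, 4, 8])
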
def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
def get_command_model_forward_for_flash_attn(shard_config: ShardConfig):
logger = logging.get_logger(__name__)
assert shard_config.enable_flash_attention, "Flash Attention is not enabled."
def forward(
self: LlamaModel,
self: CohereModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@ -562,6 +434,7 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -572,41 +445,40 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# embed positions
hidden_states = inputs_embeds
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
mask_shape = (hidden_states.shape[0], 1, past_seen_tokens, past_seen_tokens)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
@ -625,43 +497,38 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
next_decoder_cache = None
for idx, decoder_layer in enumerate(self.layers):
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
@ -672,7 +539,11 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
)
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
@ -686,10 +557,10 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
from transformers import LlamaForCausalLM
from transformers import CohereForCausalLM
def forward(
self: LlamaForCausalLM,
self: CohereForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@ -700,6 +571,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -713,9 +585,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> from transformers import AutoTokenizer, CohereForCausalLM
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> model = CohereForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
@ -744,15 +616,13 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = self.lm_head(hidden_states)
logits = logits * self.logit_scale
logits = logits.float()
loss = None
@ -788,7 +658,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
return forward
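
# A minimal sketch (not part of this commit) of the Command-R-specific logit scaling applied in
# the causal-LM forwards above: unlike Llama, CohereForCausalLM multiplies the lm_head output
# by a scalar logit_scale before computing the loss. The sizes and the 0.0625 value below are
# illustrative; the real factor comes from CohereConfig.logit_scale.
import torch

hidden_states = torch.randn(2, 4, 8)              # (batch, seq, hidden), toy sizes
lm_head = torch.nn.Linear(8, 32, bias=False)      # toy vocabulary of 32
logit_scale = 0.0625

logits = lm_head(hidden_states)
logits = logits * logit_scale                     # Command-R scales logits; Llama does not
logits = logits.float()
print(logits.shape)                               # torch.Size([2, 4, 32])
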
def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
def forward(
self,
hidden_states: torch.Tensor,
@ -797,32 +669,16 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
# sp: modify sp_len when sequence parallel mode is ring
if sp_mode in ["split_gather", "ring"]:
q_len *= sp_size
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# sp: all-to-all communication when introducing sequence parallelism
if sp_mode == "all_to_all":
@ -835,18 +691,14 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
past_key_value = getattr(self, "past_key_value", past_key_value)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
@ -854,18 +706,9 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@ -885,12 +728,8 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
@ -899,11 +738,11 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
return forward
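
# A minimal sketch (not part of this commit) of what repeat_kv does in the attention forward
# above: with grouped-query attention, key/value tensors carry num_key_value_heads heads and
# are expanded so every query head has a matching key/value head. Shapes are illustrative.
import torch

def repeat_kv_sketch(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    # hidden: (batch, num_key_value_heads, seq_len, head_dim)
    batch, num_kv_heads, seq_len, head_dim = hidden.shape
    if n_rep == 1:
        return hidden
    hidden = hidden[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

kv = torch.randn(2, 2, 5, 4)                      # 2 key/value heads
print(repeat_kv_sketch(kv, n_rep=4).shape)        # torch.Size([2, 8, 5, 4]) for 8 query heads
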
def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
logger = logging.get_logger(__name__)
def forward(
self,
self: CohereModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@ -913,6 +752,7 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -924,56 +764,13 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
# modify past_key_values_length when using sequence parallel
past_key_values_length *= sp_size
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time, and must specify either one"
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length
)
hidden_states = inputs_embeds
if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
if use_cache:
logger.warning_once(
@ -981,6 +778,29 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
)
use_cache = False
past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
@ -990,14 +810,12 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, past_key_value, output_attentions)
return module(*inputs, past_key_value=past_key_values, output_attentions=output_attentions)
return custom_forward
@ -1013,15 +831,20 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
next_decoder_cache = (
next_decoder_cache.to_legacy_cache()
if isinstance(next_decoder_cache, Cache)
else next_decoder_cache
)
if output_attentions:
all_self_attns += (layer_outputs[1],)
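
# A minimal sketch (not part of this commit) of the sequence-parallel splitting performed by
# split_forward_gather_backward in the model forward above: each rank keeps one contiguous
# chunk of the sequence dimension for its decoder layers, and the chunks are gathered again
# (with gradients split) on the backward path. No communication happens in this toy version;
# group size, rank, and shapes are illustrative.
import torch

sp_size, rank = 4, 1                          # pretend this process is rank 1 of a 4-way SP group
inputs_embeds = torch.randn(2, 16, 8)         # (batch, seq_len, hidden)

local_chunk = inputs_embeds.chunk(sp_size, dim=1)[rank]   # what this rank's layers see
print(local_chunk.shape)                                  # torch.Size([2, 4, 8])
# For sp_mode == "all_to_all", the forward above additionally rescales gradients by 1 / sp_size.
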


@ -192,6 +192,13 @@ _POLICY_LIST = {
"transformers.models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification": PolicyLocation(
file_name="qwen2", class_name="Qwen2ForSequenceClassificationPolicy"
),
# Command-R
"transformers.models.cohere.modeling_cohere.CohereModel": PolicyLocation(
file_name="command", class_name="CommandModelPolicy"
),
"transformers.models.cohere.modeling_cohere.CohereForCausalLM": PolicyLocation(
file_name="command", class_name="CommandForCausalLMPolicy"
),
}
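
# A minimal sketch (not part of this commit) of how the registrations above are consumed:
# the auto policy resolves the wrapped model's fully qualified class name against this table
# and imports the named policy class from colossalai/shardformer/policies/command.py.
# Assumes the get_autopolicy(model) helper in colossalai.shardformer.policies.auto_policy and
# a transformers build that ships Cohere; the tiny config values are illustrative.
from transformers import CohereConfig, CohereForCausalLM
from colossalai.shardformer.policies.auto_policy import get_autopolicy

tiny_config = CohereConfig(num_hidden_layers=1, hidden_size=32, intermediate_size=64, num_attention_heads=4)
policy = get_autopolicy(CohereForCausalLM(tiny_config))
print(type(policy).__name__)   # expected: CommandForCausalLMPolicy
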


@ -7,30 +7,30 @@ from torch import Tensor
from torch.nn import Module
from colossalai.shardformer.layer import (
FusedRMSNorm,
FusedCohereLayerNorm,
Linear1D_Col,
Linear1D_Row,
PaddingEmbedding,
PaddingLMHead,
RMSNorm,
CohereLayerNorm,
VocabParallelEmbedding1D,
VocabParallelLMHead1D,
)
from ..modeling.llama import (
LlamaPipelineForwards,
get_llama_flash_attention_forward,
get_llama_model_forward_for_flash_attn,
get_llama_seq_parallel_attention_forward,
get_llama_seq_parallel_model_forward,
from ..modeling.command import (
CommandPipelineForwards,
get_command_flash_attention_forward,
get_command_model_forward_for_flash_attn,
get_command_seq_parallel_attention_forward,
get_command_seq_parallel_model_forward,
get_lm_forward_with_dist_cross_entropy,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"]
__all__ = ["CommandPolicy", "CommandForCausalLMPolicy"]
class LlamaPolicy(Policy):
class CommandPolicy(Policy):
def config_sanity_check(self):
pass
@ -40,18 +40,18 @@ class LlamaPolicy(Policy):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaFlashAttention2,
LlamaModel,
LlamaSdpaAttention,
from transformers.models.cohere.modeling_cohere import (
CohereAttention,
CohereDecoderLayer,
CohereFlashAttention2,
CohereModel,
CohereSdpaAttention,
)
ATTN_IMPLEMENTATION = {
"eager": LlamaAttention,
"flash_attention_2": LlamaFlashAttention2,
"sdpa": LlamaSdpaAttention,
"eager": CohereAttention,
"flash_attention_2": CohereFlashAttention2,
"sdpa": CohereSdpaAttention,
}
policy = {}
@ -64,16 +64,16 @@ class LlamaPolicy(Policy):
embedding_cls = PaddingEmbedding
if self.shard_config.enable_fused_normalization:
norm_cls = FusedRMSNorm
norm_cls = FusedCohereLayerNorm
else:
norm_cls = RMSNorm
norm_cls = CohereLayerNorm
if self.pipeline_stage_manager is not None:
self.shard_config.enable_sequence_parallelism = False
self.shard_config.enable_sequence_overlap = False
self.shard_config.sequence_parallelism_mode = None
warnings.warn(
f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
f"For Command, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
)
sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None
@ -94,16 +94,16 @@ class LlamaPolicy(Policy):
if sp_mode in ["split_gather", "ring"]:
self.append_or_create_method_replacement(
description={
"forward": get_llama_seq_parallel_model_forward(
"forward": get_command_seq_parallel_model_forward(
sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group
),
},
policy=policy,
target_key=LlamaModel,
target_key=CohereModel,
)
self.append_or_create_method_replacement(
description={
"forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
"forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,
@ -120,21 +120,21 @@ class LlamaPolicy(Policy):
)
self.append_or_create_method_replacement(
description={
"forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
"forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,
)
self.append_or_create_method_replacement(
description={
"forward": get_llama_seq_parallel_model_forward(
"forward": get_command_seq_parallel_model_forward(
sp_mode=sp_mode,
sp_size=sp_size,
sp_group=sp_group,
),
},
policy=policy,
target_key=LlamaModel,
target_key=CohereModel,
)
if self.shard_config.enable_tensor_parallelism:
@ -155,7 +155,7 @@ class LlamaPolicy(Policy):
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
)
policy[LlamaDecoderLayer] = ModulePolicyDescription(
policy[CohereDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
@ -204,7 +204,7 @@ class LlamaPolicy(Policy):
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=LlamaModel,
target_key=CohereModel,
)
# optimization configuration
@ -215,14 +215,9 @@ class LlamaPolicy(Policy):
target_module=norm_cls,
kwargs={"sp_partial_derived": sp_partial_derived},
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=norm_cls,
kwargs={"sp_partial_derived": sp_partial_derived},
),
],
policy=policy,
target_key=LlamaDecoderLayer,
target_key=CohereDecoderLayer,
)
self.append_or_create_submodule_replacement(
@ -232,26 +227,26 @@ class LlamaPolicy(Policy):
kwargs={"sp_partial_derived": sp_partial_derived},
),
policy=policy,
target_key=LlamaModel,
target_key=CohereModel,
)
# use flash attention
if use_flash_attention:
self.append_or_create_method_replacement(
description={
"forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
"forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
},
policy=policy,
target_key=attn_cls,
)
if self.pipeline_stage_manager is None:
# replace llama model forward method
# replace Command model forward method
self.append_or_create_method_replacement(
description={
"forward": get_llama_model_forward_for_flash_attn(self.shard_config),
"forward": get_command_model_forward_for_flash_attn(self.shard_config),
},
policy=policy,
target_key=LlamaModel,
target_key=CohereModel,
)
return policy
@ -266,7 +261,7 @@ class LlamaPolicy(Policy):
return
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "LlamaModel":
if self.model.__class__.__name__ == "CohereModel":
module = self.model
else:
module = self.model.model
@ -293,7 +288,7 @@ class LlamaPolicy(Policy):
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == "LlamaModel":
if self.model.__class__.__name__ == "CohereModel":
module = self.model
else:
module = self.model.model
@ -323,15 +318,15 @@ class LlamaPolicy(Policy):
return held_layers
class LlamaModelPolicy(LlamaPolicy):
class CommandModelPolicy(CommandPolicy):
def module_policy(self):
policy = super().module_policy()
from transformers.models.llama.modeling_llama import LlamaModel
from transformers.models.cohere.modeling_cohere import CohereModel
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=LlamaModel, new_forward=LlamaPipelineForwards.llama_model_forward, policy=policy
model_cls=CohereModel, new_forward=CommandPipelineForwards.command_model_forward, policy=policy
)
return policy
@ -341,20 +336,20 @@ class LlamaModelPolicy(LlamaPolicy):
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama model"""
"""No shared params in command model"""
return []
class LlamaForCausalLMPolicy(LlamaPolicy):
class CommandForCausalLMPolicy(CommandPolicy):
def module_policy(self):
from transformers import LlamaForCausalLM
from transformers import CohereForCausalLM
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
new_item = {
LlamaForCausalLM: ModulePolicyDescription(
CohereForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
@ -368,12 +363,12 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
)
}
if self.shard_config.parallel_output:
new_item[LlamaForCausalLM].method_replacement = {
new_item[CohereForCausalLM].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
else:
new_item = {
LlamaForCausalLM: ModulePolicyDescription(
CohereForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
@ -388,7 +383,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy
model_cls=CohereForCausalLM, new_forward=CommandPipelineForwards.command_for_causal_lm_forward, policy=policy
)
return policy
@ -402,58 +397,17 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
llama_model = self.model.model
command_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
id(command_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1
):
# tie weights
return [
{
0: llama_model.embed_tokens.weight,
0: command_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}
]
return []
class LlamaForSequenceClassificationPolicy(LlamaPolicy):
def module_policy(self):
from transformers import LlamaForSequenceClassification
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
# add a new item for sequence classification
new_item = {
LlamaForSequenceClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
)
]
)
}
policy.update(new_item)
# to be confirmed
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=LlamaForSequenceClassification,
new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.score)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama for sequence classification model"""
return []
return []

diff.output (new file, 59 lines)

@ -0,0 +1,59 @@
diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py
index 5aa21260..01453a05 100644
--- a/colossalai/shardformer/layer/normalization.py
+++ b/colossalai/shardformer/layer/normalization.py
@@ -165,7 +165,7 @@ class LayerNorm(BaseLayerNorm):
Raises:
AssertionError: If the provided module is not an instance of nn.LayerNorm.
"""
- assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
+ # assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
LazyInitContext.materialize(module)
@@ -174,7 +174,7 @@ class LayerNorm(BaseLayerNorm):
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
- SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
+ # SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
return module
@@ -209,9 +209,12 @@ class FusedLayerNorm(BaseLayerNorm):
LazyInitContext.materialize(module)
# get the attributes of the module
- normalized_shape = module.normalized_shape
- eps = module.eps
- elementwise_affine = module.elementwise_affine
+ # normalized_shape = module.normalized_shape
+ # eps = module.eps
+ # elementwise_affine = module.elementwise_affine
+ normalized_shape = module.weight.size(0)
+ eps = module.variance_epsilon
+ elementwise_affine = True
dtype = module.weight.dtype
device = module.weight.device
@@ -244,7 +247,7 @@ class FusedLayerNorm(BaseLayerNorm):
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
- SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
+ # SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
return layernorm
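
# A short sketch (not part of the diff above) of why this workaround is needed: unlike
# torch.nn.LayerNorm, transformers' CohereLayerNorm only exposes `weight` and
# `variance_epsilon`; there is no `normalized_shape`, `eps`, or `elementwise_affine`
# attribute, and its bias is disabled by default. Assumes a transformers build that ships
# the Cohere model; the hidden size is illustrative.
from transformers.models.cohere.modeling_cohere import CohereLayerNorm

ln = CohereLayerNorm(hidden_size=32, eps=1e-5)
print(ln.weight.size(0), ln.variance_epsilon)                  # 32 1e-05
print(hasattr(ln, "normalized_shape"), hasattr(ln, "eps"))     # False False
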
diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py
index 6075f836..a7166e38 100644
--- a/tests/test_shardformer/test_model/test_shard_command.py
+++ b/tests/test_shardformer/test_model/test_shard_command.py
@@ -210,6 +210,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
],
)
def run_command_test(test_config):
+ print(test_config)
sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():


@ -22,3 +22,9 @@ try:
from .qwen2 import *
except ImportError:
print("This version of transformers doesn't support qwen2.")
try:
from .command import *
except ImportError:
print("This version of transformers doesn't support Command-R.")


@ -0,0 +1,81 @@
import torch
import transformers
from ..registry import ModelAttribute, model_zoo
try:
from transformers import CohereConfig
HAS_COMMAND = True
except ImportError:
HAS_COMMAND = False
if HAS_COMMAND:
# ===============================
# Register Command-R
# ===============================
def data_gen():
input_ids = torch.Tensor(
[
[1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
[1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
]
).long()
attention_mask = torch.Tensor(
[
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
]
).long()
return dict(input_ids=input_ids, attention_mask=attention_mask)
# label is needed for causal lm
def data_gen_for_casual_lm():
data = data_gen()
labels = data["input_ids"].clone()
data["labels"] = labels
return data
# transform the output to a dict
output_transform_fn = lambda x: x
# function to get the loss
loss_fn = lambda output: output["last_hidden_state"].mean()
loss_fn_for_casual_lm = lambda output: output["loss"]
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
config = CohereConfig(
num_hidden_layers=8,
hidden_size=32,
intermediate_size=64,
num_attention_heads=4,
max_position_embeddings=128,
)
if hasattr(config, "pad_token_id"):
config.pad_token_id = config.eos_token_id
# register the following models
# transformers.CohereModel,
# transformers.CohereForCausalLM,
model_zoo.register(
name="transformers_command",
model_fn=lambda: transformers.CohereModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_command_for_casual_lm",
model_fn=lambda: transformers.CohereForCausalLM(config),
data_gen_fn=data_gen_for_casual_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_casual_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)
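
# A minimal sketch (not part of this commit) of how these registrations are consumed: the
# tests pull the sub-registry by name, and each entry provides a model factory together with
# its data and loss helpers. `model_zoo` is the registry imported at the top of this file;
# the call mirrors the test file in this commit and requires a transformers build that ships Cohere.
sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
    model = model_fn()                    # tiny CohereModel / CohereForCausalLM built from the config above
    outputs = model(**data_gen_fn())      # forward pass on the toy batch
    print(name, loss_fn(output_transform_fn(outputs)))
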


@ -0,0 +1,301 @@
import os
import pytest
import torch
import torch.distributed as dist
from torch.testing import assert_close
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import PipelineGradientCheckpointConfig
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config
)
if enable_gradient_checkpointing:
# org_model.gradient_checkpointing_enable()
sharded_model.unwrap().gradient_checkpointing_enable()
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap model
command_model = unwrap_model(org_model, "CohereModel", "model")
shard_command_model = unwrap_model(sharded_model, "CohereModel", "model")
row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
col_layer_for_check = ["layers[0].self_attn.o_proj"]
# Here we check the grad of layernorm because an all-reduce operation should be performed during sequence parallelism
norm_layer_for_check = ["layers[0].input_layernorm", "layers[1].input_layernorm"]
# During pipeline parallelism, we cannot get the grad of the norm layer at the first stage, so we only check this when pp is not enabled
if stage_manager is None:
norm_layer_for_check.append("norm")
# Check the grad when using ZeRO-1 and ZeRO-2
if (
booster.plugin.zero_stage in [1, 2]
and booster.plugin.shard_config.enable_sequence_parallelism
and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
):
for p1, p2 in zip(command_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
grad_index = (
0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank
)
grad = grads[grad_index]
sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
if test_config["precision"] == "fp32":
atol, rtol = 1e-6, 1e-4
else:
atol, rtol = 5e-3, 5e-3
row_layer_grads = get_grad_tensors_for_check(
command_model, shard_command_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
)
col_layer_grads = get_grad_tensors_for_check(
command_model, shard_command_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
)
norm_layer_grads = get_grad_tensors_for_check(
command_model,
shard_command_model,
norm_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
grads_to_check.update(norm_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == "CohereModel":
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
if test_config["precision"] == "fp32":
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_weight(
command_model, shard_command_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 1,
"sp_size": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 2,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
"enable_gradient_checkpointing": True,
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"use_lazy_init": False,
"precision": "fp32",
"enable_gradient_checkpointing": True,
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
],
)
def run_command_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"pp_style": "interleaved",
"num_model_chunks": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
"enable_gradient_checkpointing": True,
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
num_ckpt_layers_per_stage=[0, 1, 2, 2],
),
},
],
)
def run_command_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
def check_command(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_command_test()
def check_command_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_command_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_command():
spawn(check_command, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_command_3d():
spawn(check_command_3d, 8)
if __name__ == "__main__":
test_command()
test_command_3d()