mirror of https://github.com/hpcaitech/ColossalAI

Commit f656d61778 (parent 0b81163bc0): "change command"
@@ -4,7 +4,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm, CohereLayerNorm, FusedCohereLayerNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row

@@ -23,6 +23,8 @@ __all__ = [
    "RMSNorm",
    "FusedLayerNorm",
    "FusedRMSNorm",
    "CohereLayerNorm",
    "FusedCohereLayerNorm",
    "FusedLinear1D_Col",
    "ParallelModule",
    "PaddingEmbedding",
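Taken together, the two hunks above export the new wrappers next to the existing normalization layers. A minimal, hypothetical smoke test (package path assumed to be colossalai.shardformer.layer from the surrounding diff; not part of the commit):

# Hypothetical check that the new exports resolve, assuming the hunks above are applied.
from colossalai.shardformer.layer import CohereLayerNorm, FusedCohereLayerNorm

print(CohereLayerNorm.__name__, FusedCohereLayerNorm.__name__)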
@@ -4,6 +4,7 @@ import warnings
from abc import ABC, abstractmethod

import torch.nn as nn
from transformers.models.cohere.modeling_cohere import CohereLayerNorm

from colossalai.lazy import LazyInitContext

@@ -249,6 +250,115 @@ class FusedLayerNorm(BaseLayerNorm):
        return layernorm


class CohereLayerNorm(BaseLayerNorm):
    r"""
    This is a wrapper around the transformers.models.cohere.CohereLayerNorm. It is meant to be used only with the from_native_module interface.
    """

    def __init__(self) -> None:
        raise NotImplementedError(
            "CohereLayerNorm is not implemented as a physical class. "
            "It is meant to be used only with the from_native_module interface to convert a transformers.models.cohere.CohereLayerNorm module to a colossalai layer norm module."
        )

    @staticmethod
    def from_native_module(module: CohereLayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
        r"""
        Convert a CohereLayerNorm module to a colossalai layer norm module,
        and optionally mark parameters for gradient aggregation.

        Args:
            module (transformers.models.cohere.CohereLayerNorm): The CohereLayerNorm module to be converted.
            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

        Returns:
            nn.Module: The LayerNorm module.

        Raises:
            AssertionError: If the provided module is not an instance of CohereLayerNorm.
        """

        LazyInitContext.materialize(module)

        if sp_partial_derived:
            # Since gradients are computed using only a subset of the data,
            # aggregation of these gradients is necessary during backpropagation.
            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)

        return module


class FusedCohereLayerNorm(BaseLayerNorm):
    r"""
    This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
    """

    def __init__(self) -> None:
        raise NotImplementedError(
            "FusedCohereLayerNorm is not implemented as a physical class. "
            "It is meant to be used only with the from_native_module interface to convert a transformers.models.cohere.CohereLayerNorm module to the FusedLayerNorm module provided by apex."
        )

    @staticmethod
    def from_native_module(module: CohereLayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
        r"""
        Convert a CohereLayerNorm module to the FusedLayerNorm module provided by apex,
        and optionally mark parameters for gradient aggregation.

        Args:
            module (transformers.models.cohere.CohereLayerNorm): The CohereLayerNorm module to be converted.
            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

        Returns:
            nn.Module: Union[FastLayerNorm, FusedLayerNorm].

        Raises:
            AssertionError: If the provided module is not an instance of transformers.models.cohere.CohereLayerNorm.
        """

        LazyInitContext.materialize(module)
        # get the attributes of the module
        normalized_shape = module.weight.size(0)
        eps = module.variance_epsilon
        elementwise_affine = True
        dtype = module.weight.dtype
        device = module.weight.device

        # pick the suitable layernorm implementation
        use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE

        if use_fast_ln:
            if EnableFastLayerNorm:
                ApexFusedLayerNorm = FastLayerNormWithHook
            else:
                # fall back to the normal fused layernorm if fast layernorm is not built
                ApexFusedLayerNorm = FusedLayerNormWithHook
        else:
            try:
                ApexFusedLayerNorm = FusedLayerNormWithHook
            except NameError:
                warnings.warn(
                    "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
                )
                return module

        layernorm = (
            ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
        )
        layernorm.weight = module.weight

        if sp_partial_derived:
            # Since gradients are computed using only a subset of the data,
            # aggregation of these gradients is necessary during backpropagation.
            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)

        return layernorm


class FusedRMSNorm(BaseLayerNorm):
    """
    This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
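The two classes above are converter shims rather than real modules: instantiating them raises, and all of the work happens in from_native_module. A rough usage sketch under stated assumptions (transformers with Cohere support installed, apex needed only for the fused path, hidden size and eps are made up, and the CohereLayerNorm constructor signature is assumed from recent transformers releases); this is illustrative and not part of the commit:

import torch
from transformers.models.cohere.modeling_cohere import CohereLayerNorm as HFCohereLayerNorm

from colossalai.shardformer.layer import FusedCohereLayerNorm

# Build a plain HuggingFace CohereLayerNorm, then convert it in place.
hf_norm = HFCohereLayerNorm(hidden_size=1024, eps=1e-5)
# Falls back to returning the original module with a warning if apex is not built.
fused_norm = FusedCohereLayerNorm.from_native_module(hf_norm, sp_partial_derived=False)
out = fused_norm(torch.randn(2, 16, 1024))
print(out.shape)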
@@ -7,21 +7,16 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.cache_utils import Cache
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.models.llama.modeling_llama import (
    LlamaForCausalLM,
    LlamaForSequenceClassification,
    LlamaModel,
    apply_rotary_pos_emb,
from transformers.models.cohere.modeling_cohere import (
    CohereForCausalLM,
    CohereModel,
    StaticCache,
    repeat_kv,
)
from transformers.utils import logging

@@ -37,15 +32,15 @@ from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, cross_entropy_1d


class LlamaPipelineForwards:
class CommandPipelineForwards:
    """
    This class serves as a micro library for forward function substitution of Llama models
    This class serves as a micro library for forward function substitution of Command models
    under pipeline setting.
    """

    @staticmethod
    def llama_model_forward(
        self: LlamaModel,
    def command_model_forward(
        self: CohereModel,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,

@@ -55,6 +50,7 @@ class LlamaPipelineForwards:
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
@ -68,6 +64,12 @@ class LlamaPipelineForwards:
|
|||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with pipeline parallelism. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
|
@ -89,8 +91,17 @@ class LlamaPipelineForwards:
|
|||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
past_seen_tokens = 0
|
||||
if use_cache: # kept for BC (cache positions)
|
||||
if not isinstance(past_key_values, StaticCache):
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_seen_tokens = past_key_values.get_seq_length()
|
||||
if cache_position is None:
|
||||
if isinstance(past_key_values, StaticCache):
|
||||
raise ValueError("cache_position is a required argument when using StaticCache.")
|
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)
|
||||
|
||||
seq_length_with_past = seq_length + past_seen_tokens
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
|
@ -103,18 +114,8 @@ class LlamaPipelineForwards:
|
|||
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||
use_cache = False
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
|
@ -129,28 +130,9 @@ class LlamaPipelineForwards:
|
|||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._use_sdpa and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
hidden_states,
|
||||
past_key_values_length,
|
||||
)
|
||||
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
|
@ -190,6 +172,7 @@ class LlamaPipelineForwards:
|
|||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
|
@ -199,6 +182,7 @@ class LlamaPipelineForwards:
|
|||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
@ -237,8 +221,8 @@ class LlamaPipelineForwards:
|
|||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def llama_for_causal_lm_forward(
|
||||
self: LlamaForCausalLM,
|
||||
def command_for_causal_lm_forward(
|
||||
self: CohereForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
|
@ -249,6 +233,7 @@ class LlamaPipelineForwards:
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
|
@ -266,9 +251,9 @@ class LlamaPipelineForwards:
|
|||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
>>> from transformers import AutoTokenizer, CohereForCausalLM
|
||||
|
||||
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> model = CohereForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
|
@ -295,7 +280,7 @@ class LlamaPipelineForwards:
|
|||
output_hidden_states = False
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = LlamaPipelineForwards.llama_model_forward(
|
||||
outputs = CommandPipelineForwards.command_model_forward(
|
||||
self.model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
|
@ -306,6 +291,7 @@ class LlamaPipelineForwards:
|
|||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
|
@ -316,6 +302,8 @@ class LlamaPipelineForwards:
|
|||
if stage_manager.is_last_stage():
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits * self.logit_scale
|
||||
logits = logits.float()
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
|
@ -355,137 +343,20 @@ class LlamaPipelineForwards:
|
|||
hidden_states = outputs.get("hidden_states")
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def llama_for_sequence_classification_forward(
|
||||
self: LlamaForSequenceClassification,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
logger = logging.get_logger(__name__)
|
||||
def get_command_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
|
||||
from transformers.models.cohere.modeling_cohere import CohereAttention, apply_rotary_pos_emb
|
||||
from transformers.models.cohere.modeling_cohere import repeat_kv
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
transformer_outputs = LlamaPipelineForwards.llama_model_forward(
|
||||
self.model,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size = input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
else:
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = transformer_outputs[0]
|
||||
logits = self.score(hidden_states)
|
||||
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutputWithPast(
|
||||
loss=loss,
|
||||
logits=pooled_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = transformer_outputs.get("hidden_states")
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
|
||||
def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
|
||||
|
||||
try:
|
||||
from transformers.models.llama.modeling_llama import repeat_kv
|
||||
except:
|
||||
warnings.warn("using llamav1, llamav1 hasn't repeat_kv function")
|
||||
|
||||
def forward(
|
||||
self: LlamaAttention,
|
||||
self: CohereAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[dict] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
|
@ -520,13 +391,14 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
|
|||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
|
@ -547,12 +419,12 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
|
|||
return forward
|
||||
|
||||
|
||||
def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
|
||||
def get_command_model_forward_for_flash_attn(shard_config: ShardConfig):
|
||||
logger = logging.get_logger(__name__)
|
||||
assert shard_config.enable_flash_attention, "Flash Attention is not enabled."
|
||||
|
||||
def forward(
|
||||
self: LlamaModel,
|
||||
self: CohereModel,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
|
@ -562,6 +434,7 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
|
@ -572,41 +445,40 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
past_seen_tokens = 0
|
||||
if use_cache: # kept for BC (cache positions)
|
||||
if not isinstance(past_key_values, StaticCache):
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_seen_tokens = past_key_values.get_seq_length()
|
||||
if cache_position is None:
|
||||
if isinstance(past_key_values, StaticCache):
|
||||
raise ValueError("cache_position is a required argument when using StaticCache.")
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# in this case, attention_mask is a dict rather than a tensor
|
||||
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
|
||||
mask_shape = (hidden_states.shape[0], 1, past_seen_tokens, past_seen_tokens)
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
hidden_states.dtype,
|
||||
|
@ -625,43 +497,38 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, past_key_value, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
@ -672,7 +539,11 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = (
|
||||
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
|
||||
)
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
|
@ -686,10 +557,10 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|||
|
||||
|
||||
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||
from transformers import LlamaForCausalLM
|
||||
from transformers import CohereForCausalLM
|
||||
|
||||
def forward(
|
||||
self: LlamaForCausalLM,
|
||||
self: CohereForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
|
@ -700,6 +571,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
|
@ -713,9 +585,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
>>> from transformers import AutoTokenizer, CohereForCausalLM
|
||||
|
||||
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> model = CohereForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
|
@ -744,15 +616,13 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if self.config.pretraining_tp > 1:
|
||||
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
||||
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
logits = torch.cat(logits, dim=-1)
|
||||
else:
|
||||
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits * self.logit_scale
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
|
@ -788,7 +658,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
return forward
|
||||
|
||||
|
||||
def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
|
||||
def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
|
||||
from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
@ -797,29 +669,13 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
|
|||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
# sp: modify sp_len when sequence parallel mode is ring
|
||||
if sp_mode in ["split_gather", "ring"]:
|
||||
q_len *= sp_size
|
||||
if self.config.pretraining_tp > 1:
|
||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
@ -835,18 +691,14 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
|
|||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
|
@ -854,18 +706,9 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
|
|||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
|
@ -885,11 +728,7 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
|
|||
else:
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
||||
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
||||
else:
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
|
@ -899,11 +738,11 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
|
|||
return forward
|
||||
|
||||
|
||||
def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
|
||||
def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
self: CohereModel,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
|
@ -913,6 +752,7 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
|
@ -924,56 +764,13 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
|
|||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
# modify past_key_values_length when using sequence parallel
|
||||
past_key_values_length *= sp_size
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if sp_mode in ["ring", "split_gather"]:
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
|
||||
elif sp_mode == "all_to_all":
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
|
@ -981,6 +778,29 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
|
|||
)
|
||||
use_cache = False
|
||||
|
||||
past_seen_tokens = 0
|
||||
if use_cache: # kept for BC (cache positions)
|
||||
if not isinstance(past_key_values, StaticCache):
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_seen_tokens = past_key_values.get_seq_length()
|
||||
if cache_position is None:
|
||||
if isinstance(past_key_values, StaticCache):
|
||||
raise ValueError("cache_position is a required argument when using StaticCache.")
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
||||
|
||||
if sp_mode in ["ring", "split_gather"]:
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
|
||||
elif sp_mode == "all_to_all":
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
@ -990,14 +810,12 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
|
|||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, past_key_value, output_attentions)
|
||||
return module(*inputs, past_key_value=past_key_values, output_attentions=output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
|
@ -1013,15 +831,20 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
|
|||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
next_decoder_cache = (
|
||||
next_decoder_cache.to_legacy_cache()
|
||||
if isinstance(next_decoder_cache, Cache)
|
||||
else next_decoder_cache
|
||||
)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
|
|
@@ -192,6 +192,13 @@ _POLICY_LIST = {
    "transformers.models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification": PolicyLocation(
        file_name="qwen2", class_name="Qwen2ForSequenceClassificationPolicy"
    ),
    # Command-R
    "transformers.models.cohere.modeling_cohere.CohereModel": PolicyLocation(
        file_name="command", class_name="CommandModelPolicy"
    ),
    "transformers.models.cohere.modeling_cohere.CohereForCausalLM": PolicyLocation(
        file_name="command", class_name="CommandForCausalLMPolicy"
    ),
}

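For orientation, the registry above keys policies by the fully qualified class name of the HuggingFace module. A self-contained stand-in showing the lookup the auto policy is expected to perform (the real PolicyLocation and registry live in colossalai.shardformer.policies.auto_policy; names here are only illustrative):

from dataclasses import dataclass


# Minimal stand-in for the PolicyLocation entries registered above; illustrative only.
@dataclass
class PolicyLocation:
    file_name: str
    class_name: str


_POLICY_LIST = {
    "transformers.models.cohere.modeling_cohere.CohereForCausalLM": PolicyLocation(
        file_name="command", class_name="CommandForCausalLMPolicy"
    ),
}


def locate_policy(model) -> PolicyLocation:
    # Build the fully qualified class name used as the registry key.
    qualname = f"{type(model).__module__}.{type(model).__qualname__}"
    return _POLICY_LIST[qualname]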
@@ -7,30 +7,30 @@ from torch import Tensor
from torch.nn import Module

from colossalai.shardformer.layer import (
    FusedRMSNorm,
    FusedCohereLayerNorm,
    Linear1D_Col,
    Linear1D_Row,
    PaddingEmbedding,
    PaddingLMHead,
    RMSNorm,
    CohereLayerNorm,
    VocabParallelEmbedding1D,
    VocabParallelLMHead1D,
)

from ..modeling.llama import (
    LlamaPipelineForwards,
    get_llama_flash_attention_forward,
    get_llama_model_forward_for_flash_attn,
    get_llama_seq_parallel_attention_forward,
    get_llama_seq_parallel_model_forward,
from ..modeling.command import (
    CommandPipelineForwards,
    get_command_flash_attention_forward,
    get_command_model_forward_for_flash_attn,
    get_command_seq_parallel_attention_forward,
    get_command_seq_parallel_model_forward,
    get_lm_forward_with_dist_cross_entropy,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"]
__all__ = ["CommandPolicy", "CommandForCausalLMPolicy"]


class LlamaPolicy(Policy):
class CommandPolicy(Policy):
    def config_sanity_check(self):
        pass

@@ -40,18 +40,18 @@ class LlamaPolicy(Policy):
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        from transformers.models.llama.modeling_llama import (
            LlamaAttention,
            LlamaDecoderLayer,
            LlamaFlashAttention2,
            LlamaModel,
            LlamaSdpaAttention,
        from transformers.models.cohere.modeling_cohere import (
            CohereAttention,
            CohereDecoderLayer,
            CohereFlashAttention2,
            CohereModel,
            CohereSdpaAttention,
        )

        ATTN_IMPLEMENTATION = {
            "eager": LlamaAttention,
            "flash_attention_2": LlamaFlashAttention2,
            "sdpa": LlamaSdpaAttention,
            "eager": CohereAttention,
            "flash_attention_2": CohereFlashAttention2,
            "sdpa": CohereSdpaAttention,
        }
        policy = {}

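Once CommandPolicy and CommandForCausalLMPolicy exist, they plug into shardformer like any other policy. A rough wiring sketch, with API names assumed from ColossalAI's existing shardformer usage and this diff, a placeholder checkpoint id, and the assumption that the distributed environment has already been initialized (e.g. via colossalai.launch); not part of the commit:

# Hypothetical sketch: shard a Cohere/Command-R causal LM with the policy added here.
from transformers import CohereForCausalLM

from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.command import CommandForCausalLMPolicy

model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")  # placeholder id
shard_config = ShardConfig(enable_tensor_parallelism=True, enable_fused_normalization=True)
shard_former = ShardFormer(shard_config=shard_config)
# optimize() is expected to return the sharded model plus any tied/shared parameters.
sharded_model, shared_params = shard_former.optimize(model, policy=CommandForCausalLMPolicy())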
@ -64,16 +64,16 @@ class LlamaPolicy(Policy):
|
|||
embedding_cls = PaddingEmbedding
|
||||
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
norm_cls = FusedRMSNorm
|
||||
norm_cls = FusedCohereLayerNorm
|
||||
else:
|
||||
norm_cls = RMSNorm
|
||||
norm_cls = CohereLayerNorm
|
||||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
self.shard_config.enable_sequence_overlap = False
|
||||
self.shard_config.sequence_parallelism_mode = None
|
||||
warnings.warn(
|
||||
f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
|
||||
f"For Command, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
|
||||
)
|
||||
sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
|
||||
sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None
|
||||
|
@ -94,16 +94,16 @@ class LlamaPolicy(Policy):
|
|||
if sp_mode in ["split_gather", "ring"]:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_llama_seq_parallel_model_forward(
|
||||
"forward": get_command_seq_parallel_model_forward(
|
||||
sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group
|
||||
),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=LlamaModel,
|
||||
target_key=CohereModel,
|
||||
)
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
|
||||
"forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=attn_cls,
|
||||
|
@ -120,21 +120,21 @@ class LlamaPolicy(Policy):
|
|||
)
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
|
||||
"forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=attn_cls,
|
||||
)
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_llama_seq_parallel_model_forward(
|
||||
"forward": get_command_seq_parallel_model_forward(
|
||||
sp_mode=sp_mode,
|
||||
sp_size=sp_size,
|
||||
sp_group=sp_group,
|
||||
),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=LlamaModel,
|
||||
target_key=CohereModel,
|
||||
)
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
|
@ -155,7 +155,7 @@ class LlamaPolicy(Policy):
|
|||
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
|
||||
)
|
||||
|
||||
policy[LlamaDecoderLayer] = ModulePolicyDescription(
|
||||
policy[CohereDecoderLayer] = ModulePolicyDescription(
|
||||
attribute_replacement=decoder_attribute_replacement,
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -204,7 +204,7 @@ class LlamaPolicy(Policy):
|
|||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
policy=policy,
|
||||
target_key=LlamaModel,
|
||||
target_key=CohereModel,
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
|
@ -215,14 +215,9 @@ class LlamaPolicy(Policy):
|
|||
target_module=norm_cls,
|
||||
kwargs={"sp_partial_derived": sp_partial_derived},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="post_attention_layernorm",
|
||||
target_module=norm_cls,
|
||||
kwargs={"sp_partial_derived": sp_partial_derived},
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=LlamaDecoderLayer,
|
||||
target_key=CohereDecoderLayer,
|
||||
)
|
||||
|
||||
self.append_or_create_submodule_replacement(
|
||||
|
@ -232,26 +227,26 @@ class LlamaPolicy(Policy):
|
|||
kwargs={"sp_partial_derived": sp_partial_derived},
|
||||
),
|
||||
policy=policy,
|
||||
target_key=LlamaModel,
|
||||
target_key=CohereModel,
|
||||
)
|
||||
|
||||
# use flash attention
|
||||
if use_flash_attention:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
|
||||
"forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=attn_cls,
|
||||
)
|
||||
if self.pipeline_stage_manager is None:
|
||||
# replace llama model forward method
|
||||
# replace Command model forward method
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_llama_model_forward_for_flash_attn(self.shard_config),
|
||||
"forward": get_command_model_forward_for_flash_attn(self.shard_config),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=LlamaModel,
|
||||
target_key=CohereModel,
|
||||
)
|
||||
|
||||
return policy
|
||||
|
@ -266,7 +261,7 @@ class LlamaPolicy(Policy):
|
|||
return
|
||||
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if self.model.__class__.__name__ == "LlamaModel":
|
||||
if self.model.__class__.__name__ == "CohereModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
|
@ -293,7 +288,7 @@ class LlamaPolicy(Policy):
|
|||
"""Get pipeline layers for current stage."""
|
||||
assert self.pipeline_stage_manager is not None
|
||||
|
||||
if self.model.__class__.__name__ == "LlamaModel":
|
||||
if self.model.__class__.__name__ == "CohereModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
|
@ -323,15 +318,15 @@ class LlamaPolicy(Policy):
|
|||
return held_layers
|
||||
|
||||
|
||||
class LlamaModelPolicy(LlamaPolicy):
|
||||
class CommandModelPolicy(CommandPolicy):
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
from transformers.models.llama.modeling_llama import LlamaModel
|
||||
from transformers.models.cohere.modeling_cohere import CohereModel
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=LlamaModel, new_forward=LlamaPipelineForwards.llama_model_forward, policy=policy
|
||||
model_cls=CohereModel, new_forward=CommandPipelineForwards.command_model_forward, policy=policy
|
||||
)
|
||||
return policy
|
||||
|
||||
|
@ -341,20 +336,20 @@ class LlamaModelPolicy(LlamaPolicy):
|
|||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
"""No shared params in llama model"""
|
||||
"""No shared params in command model"""
|
||||
return []
|
||||
|
||||
|
||||
class LlamaForCausalLMPolicy(LlamaPolicy):
|
||||
class CommandForCausalLMPolicy(CommandPolicy):
|
||||
def module_policy(self):
|
||||
from transformers import LlamaForCausalLM
|
||||
from transformers import CohereForCausalLM
|
||||
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for casual lm
|
||||
new_item = {
|
||||
LlamaForCausalLM: ModulePolicyDescription(
|
||||
CohereForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
|
@ -368,12 +363,12 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
|||
)
|
||||
}
|
||||
if self.shard_config.parallel_output:
|
||||
new_item[LlamaForCausalLM].method_replacement = {
|
||||
new_item[CohereForCausalLM].method_replacement = {
|
||||
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
|
||||
}
|
||||
else:
|
||||
new_item = {
|
||||
LlamaForCausalLM: ModulePolicyDescription(
|
||||
CohereForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
|
@ -388,7 +383,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
|||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy
|
||||
model_cls=CohereForCausalLM, new_forward=CommandPipelineForwards.command_for_causal_lm_forward, policy=policy
|
||||
)
|
||||
|
||||
return policy
|
||||
|
@ -402,58 +397,17 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
|||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
llama_model = self.model.model
|
||||
command_model = self.model.model
|
||||
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||
if (
|
||||
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||
id(command_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||
and self.pipeline_stage_manager.num_stages > 1
|
||||
):
|
||||
# tie weights
|
||||
return [
|
||||
{
|
||||
0: llama_model.embed_tokens.weight,
|
||||
0: command_model.embed_tokens.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
}
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
||||
def module_policy(self):
|
||||
from transformers import LlamaForSequenceClassification
|
||||
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for sequence classification
|
||||
new_item = {
|
||||
LlamaForSequenceClassification: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
# to be confirmed
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=LlamaForSequenceClassification,
|
||||
new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward,
|
||||
policy=policy,
|
||||
)
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.score)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
"""No shared params in llama for sequence classification model"""
|
||||
return []
|
||||
|
|
|
@@ -0,0 +1,59 @@
diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py
index 5aa21260..01453a05 100644
--- a/colossalai/shardformer/layer/normalization.py
+++ b/colossalai/shardformer/layer/normalization.py
@@ -165,7 +165,7 @@ class LayerNorm(BaseLayerNorm):
        Raises:
            AssertionError: If the provided module is not an instance of nn.LayerNorm.
        """
-        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
+        # assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."

        LazyInitContext.materialize(module)

@@ -174,7 +174,7 @@ class LayerNorm(BaseLayerNorm):
            # aggregation of these gradients is necessary during backpropagation.
            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
-            SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
+            # SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)

        return module

@@ -209,9 +209,12 @@ class FusedLayerNorm(BaseLayerNorm):

        LazyInitContext.materialize(module)
        # get the attributes of the module
-        normalized_shape = module.normalized_shape
-        eps = module.eps
-        elementwise_affine = module.elementwise_affine
+        # normalized_shape = module.normalized_shape
+        # eps = module.eps
+        # elementwise_affine = module.elementwise_affine
+        normalized_shape = module.weight.size(0)
+        eps = module.variance_epsilon
+        elementwise_affine = True
        dtype = module.weight.dtype
        device = module.weight.device

@@ -244,7 +247,7 @@ class FusedLayerNorm(BaseLayerNorm):
            # aggregation of these gradients is necessary during backpropagation.
            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
-            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
+            # SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)

        return layernorm

diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py
index 6075f836..a7166e38 100644
--- a/tests/test_shardformer/test_model/test_shard_command.py
+++ b/tests/test_shardformer/test_model/test_shard_command.py
@@ -210,6 +210,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    ],
)
def run_command_test(test_config):
+    print(test_config)
    sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
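The patch above compensates for the fact that CohereLayerNorm exposes only a weight and a variance_epsilon, not the normalized_shape / eps / elementwise_affine / bias attributes of nn.LayerNorm. A sketch of that attribute mapping, using a hypothetical stand-in class rather than the real transformers module:

# --- illustrative sketch, not part of the patch ---
import torch
import torch.nn as nn

class CohereStyleNorm(nn.Module):
    # hypothetical stand-in: only `weight` and `variance_epsilon`, no bias
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

module = CohereStyleNorm(32)
normalized_shape = module.weight.size(0)  # derived from the weight, here 32
eps = module.variance_epsilon             # takes the place of nn.LayerNorm's eps
elementwise_affine = True                 # the norm always carries a learnable weight
print(normalized_shape, eps, elementwise_affine)
# --- end sketch ---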
@@ -22,3 +22,9 @@ try:
    from .qwen2 import *
except ImportError:
    print("This version of transformers doesn't support qwen2.")
+
+
+try:
+    from .command import *
+except ImportError:
+    print("This version of transformers doesn't support Command-R.")
@@ -0,0 +1,81 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

try:
    from transformers import CohereConfig

    HAS_COMMAND = True
except ImportError:
    HAS_COMMAND = False

if HAS_COMMAND:
    # ===============================
    # Register Command-R
    # ===============================

    def data_gen():


        input_ids = torch.Tensor(
            [
                [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
                [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
            ]
        ).long()

        attention_mask = torch.Tensor(
            [
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            ]
        ).long()

        return dict(input_ids=input_ids, attention_mask=attention_mask)

    # label is needed for causal lm
    def data_gen_for_casual_lm():
        data = data_gen()
        labels = data["input_ids"].clone()
        data["labels"] = labels
        return data

    # transform the output to a dict
    output_transform_fn = lambda x: x

    # function to get the loss
    loss_fn = lambda output: output["last_hidden_state"].mean()
    loss_fn_for_casual_lm = lambda output: output["loss"]
    loss_fn_for_seq_classification = lambda output: output["logits"].mean()

    config = CohereConfig(
        num_hidden_layers=8,
        hidden_size=32,
        intermediate_size=64,
        num_attention_heads=4,
        max_position_embeddings=128,
    )

    if hasattr(config, "pad_token_id"):
        config.pad_token_id = config.eos_token_id

    # register the following models
    # transformers.CohereModel,
    # transformers.CohereForCausalLM,
    model_zoo.register(
        name="transformers_command",
        model_fn=lambda: transformers.CohereModel(config),
        data_gen_fn=data_gen,
        output_transform_fn=output_transform_fn,
        loss_fn=loss_fn,
        model_attribute=ModelAttribute(has_control_flow=True),
    )
    model_zoo.register(
        name="transformers_command_for_casual_lm",
        model_fn=lambda: transformers.CohereForCausalLM(config),
        data_gen_fn=data_gen_for_casual_lm,
        output_transform_fn=output_transform_fn,
        loss_fn=loss_fn_for_casual_lm,
        model_attribute=ModelAttribute(has_control_flow=True),
    )
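The registration above builds a deliberately tiny Cohere configuration so the shard tests stay cheap. A quick smoke test of that configuration, assuming a transformers release that ships the Cohere architecture (the same condition the HAS_COMMAND guard checks):

# --- illustrative sketch, not part of the patch ---
import torch
import transformers

config = transformers.CohereConfig(
    num_hidden_layers=8,
    hidden_size=32,
    intermediate_size=64,
    num_attention_heads=4,
    max_position_embeddings=128,
)
model = transformers.CohereModel(config)
batch = dict(
    input_ids=torch.randint(0, config.vocab_size, (2, 16)),
    attention_mask=torch.ones(2, 16, dtype=torch.long),
)
out = model(**batch)
print(out.last_hidden_state.shape)  # expected: torch.Size([2, 16, 32])
# --- end sketch ---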
@@ -0,0 +1,301 @@
import os

import pytest
import torch
import torch.distributed as dist
from torch.testing import assert_close

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import PipelineGradientCheckpointConfig
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
    build_model_from_hybrid_plugin,
    check_all_grad_tensors,
    check_loss,
    check_output_hidden_state,
    check_weight,
    get_grad_tensors_for_check,
    run_forward_backward_with_hybrid_plugin,
    unwrap_model,
)

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
    enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
        model_fn, loss_fn, test_config
    )
    if enable_gradient_checkpointing:
        # org_model.gradient_checkpointing_enable()
        sharded_model.unwrap().gradient_checkpointing_enable()

    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
    )

    stage_manager = booster.plugin.stage_manager
    tp_group = booster.plugin.tp_group

    # unwrap model
    command_model = unwrap_model(org_model, "CohereModel", "model")
    shard_command_model = unwrap_model(sharded_model, "CohereModel", "model")

    row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
    col_layer_for_check = ["layers[0].self_attn.o_proj"]
    # Here we check the grad of layernorm because an all-reduce operation should be performed during sequence parallelism
    norm_layer_for_check = ["layers[0].input_layernorm", "layers[1].input_layernorm"]

    # During pipeline parallelism, we cannot get the grad of the norm layer on the first stage, so we only check this when pp is not enabled
    if stage_manager is None:
        norm_layer_for_check.append("norm")

    # Check the grad when using ZeRO-1 and ZeRO-2
    if (
        booster.plugin.zero_stage in [1, 2]
        and booster.plugin.shard_config.enable_sequence_parallelism
        and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
    ):
        for p1, p2 in zip(command_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
            working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
            grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
            grad_index = (
                0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank
            )
            grad = grads[grad_index]
            sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
            assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)

    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
    grads_to_check = {}
    if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-6, 1e-4
        else:
            atol, rtol = 5e-3, 5e-3
        row_layer_grads = get_grad_tensors_for_check(
            command_model, shard_command_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
        )
        col_layer_grads = get_grad_tensors_for_check(
            command_model, shard_command_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
        )
        norm_layer_grads = get_grad_tensors_for_check(
            command_model,
            shard_command_model,
            norm_layer_for_check,
            tp_group,
            atol=atol,
            rtol=rtol,
            dim=1,
            verbose=False,
        )
        grads_to_check.update(col_layer_grads)
        grads_to_check.update(row_layer_grads)
        grads_to_check.update(norm_layer_grads)

    # optimizer executes step
    org_optimizer.step()
    sharded_optimizer.step()

    # check last hidden state & loss
    if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-5, 1e-3
        else:
            atol, rtol = 5e-3, 5e-3

        if org_model.__class__.__name__ == "CohereModel":
            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

    # check weights
    if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-4, 1e-3
        else:
            atol, rtol = 5e-3, 5e-3
        check_weight(
            command_model, shard_command_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
        )

    # check grads
    check_all_grad_tensors(grads_to_check)

    torch.cuda.empty_cache()


@parameterize(
    "test_config",
    [
        {
            "tp_size": 2,
            "pp_size": 1,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "ring",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 2,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 4,
            "pp_size": 1,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "split_gather",
            "enable_flash_attention": False,
            "use_lazy_init": True,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 1,
            "pp_size": 1,
            "sp_size": 2,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "use_lazy_init": True,
            "zero_stage": 2,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 2,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "precision": "fp16",
            "initial_scale": 1,
            "enable_gradient_checkpointing": True,
            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
        },
        {
            "tp_size": 1,
            "pp_size": 2,
            "num_microbatches": 4,
            "use_lazy_init": False,
            "precision": "fp32",
            "enable_gradient_checkpointing": True,
            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
        },
        {
            "tp_size": 2,
            "pp_size": 1,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "zero_stage": 2,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 1,
            "pp_size": 2,
            "num_microbatches": 2,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
    ],
)
def run_command_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()


@parameterize(
    "test_config",
    [
        {
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "use_lazy_init": False,
            "precision": "fp32",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "use_lazy_init": False,
            "precision": "fp16",
            "zero_stage": 1,
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "pp_style": "interleaved",
            "num_model_chunks": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "precision": "fp16",
            "zero_stage": 1,
            "initial_scale": 1,
            "enable_gradient_checkpointing": True,
            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
                num_ckpt_layers_per_stage=[0, 1, 2, 2],
            ),
        },
    ],
)
def run_command_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()


def check_command(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_command_test()


def check_command_3d(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_command_3d_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_command():
    spawn(check_command, 4)


@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_command_3d():
    spawn(check_command_3d, 8)


if __name__ == "__main__":
    test_command()
    test_command_3d()
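The ZeRO-1/ZeRO-2 branch of check_forward_backward compares each rank's partitioned gradient against a chunk of the dense gradient. A sketch of that chunking with a simulated rank and world size (no process group is created here):

# --- illustrative sketch, not part of the patch ---
import torch

world_size, rank = 4, 1                # simulated; the test uses dist.get_world_size()/get_rank()
full_grad = torch.arange(16.0)         # stands in for p1.grad.view(-1)
local_chunk = full_grad.chunk(world_size)[rank]
print(local_chunk)                     # tensor([4., 5., 6., 7.]), rank 1's slice
# --- end sketch ---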