Mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #5818 from GuangyaoZhang/command-r
[shardformer] Support the Command-R model (pull/5832/head)
commit 639394b0d4
@@ -140,32 +140,29 @@ class RMSNorm(BaseLayerNorm):
 class LayerNorm(BaseLayerNorm):
     r"""
-    This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
+    This is a wrapper around native LayerNorm. It is meant to be used only with the from_native_module interface.
     """

     def __init__(self) -> None:
         raise NotImplementedError(
             "LayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module."
+            "It is meant to be used only with the from_native_module interface to convert a native LayerNorm module to colossalai layer norm module."
         )

     @staticmethod
-    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
         r"""
-        Convert a native pytorch layer norm module to colossalai layer norm module,
+        Convert a native LayerNorm module to colossalai layer norm module,
         and optionally marking parameters for gradient aggregation.

         Args:
-            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            module (nn.Module): The native LayerNorm module to be converted.
             sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

         Returns:
-            nn.Module: The LayerNorm module.
+            nn.Module: The colossalai LayerNorm module.

-        Raises:
-            AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """
-        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
-
         LazyInitContext.materialize(module)

@@ -174,6 +171,7 @@ class LayerNorm(BaseLayerNorm):
         # aggregation of these gradients is necessary during backpropagation.
         # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
         SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
+        if module.bias is not None:
             SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)

         return module

@@ -187,31 +185,29 @@ class FusedLayerNorm(BaseLayerNorm):
     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
+            "It is meant to be used only with the from_native_module interface convert a native LayerNorm module to FusedLayerNorm module provided by apex."
         )

     @staticmethod
     def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
         r"""
-        Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex,
+        Convert a native LayerNorm module to FusedLayerNorm module provided by apex,
         and optionally marking parameters for gradient aggregation.

         Args:
-            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            module (nn.Module): The native LayerNorm module to be converted.
             sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

         Returns:
             nn.Module: Union[FastLayerNorm, FusedLayerNorm].

-        Raises:
-            AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """

         LazyInitContext.materialize(module)
         # get the attributes of the module
-        normalized_shape = module.normalized_shape
-        eps = module.eps
-        elementwise_affine = module.elementwise_affine
+        normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
+        eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
+        elementwise_affine = getattr(module, "elementwise_affine", True)
         dtype = module.weight.dtype
         device = module.weight.device

@@ -229,7 +225,7 @@ class FusedLayerNorm(BaseLayerNorm):
             ApexFusedLayerNorm = FusedLayerNormWithHook
         except NameError:
             warnings.warn(
-                "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
+                "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using native layernorm instead."
             )
             return module

@@ -237,6 +233,7 @@ class FusedLayerNorm(BaseLayerNorm):
             ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
         )
         layernorm.weight = module.weight
+        if module.bias is not None:
             layernorm.bias = module.bias

         if sp_partial_derived:
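The getattr/hasattr fallbacks above are what let from_native_module accept layer norms that do not subclass nn.LayerNorm, such as the Cohere norm used by Command-R, which stores its epsilon under variance_epsilon and has no bias. A minimal usage sketch, assuming a plain nn.LayerNorm input (the size is illustrative only) and that Apex may or may not be installed:

```python
import torch.nn as nn

from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm

# Toy module; any norm exposing a weight and an eps-like attribute is accepted.
native_norm = nn.LayerNorm(1024, eps=1e-5)

# LayerNorm.from_native_module returns the same module object, optionally
# marking weight/bias as partially derived for sequence parallelism.
wrapped = LayerNorm.from_native_module(native_norm, sp_partial_derived=True)

# FusedLayerNorm.from_native_module swaps in the Apex fused kernel when it is
# available and falls back to the native module (with a warning) otherwise.
fused = FusedLayerNorm.from_native_module(native_norm, sp_partial_derived=False)
```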
@@ -0,0 +1,692 @@
import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import (
    CohereForCausalLM,
    CohereModel,
    StaticCache,
    apply_rotary_pos_emb,
    repeat_kv,
)
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import (
    all_to_all_comm,
    gather_forward_split_backward,
    split_forward_gather_backward,
)
from colossalai.shardformer.shard import ShardConfig

from ..layer import ColoAttention, cross_entropy_1d


class CommandPipelineForwards:
    """
    This class serves as a micro library for forward function substitution of Command models
    under pipeline setting.
    """

    @staticmethod
    def command_model_forward(
        self: CohereModel,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
    ):
        logger = logging.get_logger(__name__)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with pipeline parallelism. Setting `use_cache=False`..."
            )
            use_cache = False

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if stage_manager.is_first_stage():
            if input_ids is not None and inputs_embeds is not None:
                raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
            elif input_ids is not None:
                batch_size, seq_length = input_ids.shape[:2]
            elif inputs_embeds is not None:
                batch_size, seq_length, _ = inputs_embeds.shape[:2]
            else:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            if inputs_embeds is None:
                inputs_embeds = self.embed_tokens(input_ids)
            hidden_states = inputs_embeds
        else:
            input_shape = hidden_states.shape[:-1]
            batch_size, seq_length = input_shape
            device = hidden_states.device

        past_seen_tokens = 0
        if use_cache:  # kept for BC (cache positions)
            if not isinstance(past_key_values, StaticCache):
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                past_seen_tokens = past_key_values.get_seq_length()
        if cache_position is None:
            if isinstance(past_key_values, StaticCache):
                raise ValueError("cache_position is a required argument when using StaticCache.")
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)

        seq_length_with_past = seq_length + past_seen_tokens

        # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
        if output_attentions:
            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
            output_attentions = False
        if output_hidden_states:
            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
            output_hidden_states = False
        if use_cache:
            logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
            use_cache = False

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # embed positions, for the first stage, hidden_states is the input embeddings,
        # for the other stages, hidden_states is the output of the previous stage
        if shard_config.enable_flash_attention:
            # in this case, attention_mask is a dict rather than a tensor
            mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
            attention_mask = ColoAttention.prepare_attn_kwargs(
                mask_shape,
                hidden_states.dtype,
                hidden_states.device,
                q_padding_mask=attention_mask,
                is_causal=True,
            )
        else:
            attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)

        if self.gradient_checkpointing and self.training and use_cache:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        start_idx, end_idx = stage_index[0], stage_index[1]
        num_ckpt_layers = 0
        if self.gradient_checkpointing and self.training:
            num_ckpt_layers = end_idx - start_idx
            # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
            if shard_config.gradient_checkpoint_config is not None:
                num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
                    stage=stage_manager.stage,
                    num_stages=stage_manager.num_stages,
                    num_layers=end_idx - start_idx,
                    model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
                    num_model_chunks=stage_manager.num_model_chunks,
                )
            assert num_ckpt_layers <= end_idx - start_idx

        for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if idx - start_idx < num_ckpt_layers:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if stage_manager.is_last_stage():
            hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        next_cache = next_decoder_cache if use_cache else None
        if stage_manager.is_last_stage():
            if not return_dict:
                return tuple(
                    v
                    for v in [
                        hidden_states,
                        next_cache,
                        all_hidden_states,
                        all_self_attns,
                    ]
                    if v is not None
                )
            return BaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=next_cache,
                hidden_states=all_hidden_states,
                attentions=all_self_attns,
            )
        # always return dict for intermediate stage
        return {"hidden_states": hidden_states}

    @staticmethod
    def command_for_causal_lm_forward(
        self: CohereForCausalLM,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
    ):
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, CohereForCausalLM

        >>> model = CohereForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        logger = logging.get_logger(__name__)
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
        if output_attentions:
            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
            output_attentions = False
        if output_hidden_states:
            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
            output_hidden_states = False

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = CommandPipelineForwards.command_model_forward(
            self.model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            stage_manager=stage_manager,
            hidden_states=hidden_states,
            stage_index=stage_index,
            shard_config=shard_config,
        )
        past_key_values = None

        if stage_manager.is_last_stage():
            hidden_states = outputs[0]
            logits = self.lm_head(hidden_states)
            logits = logits * self.logit_scale
            logits = logits.float()
            loss = None
            if labels is not None:
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                loss_fct = CrossEntropyLoss()
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
                    new_vocab_size = logits.shape[-1]
                    shift_logits = shift_logits.view(-1, new_vocab_size)
                    loss = cross_entropy_1d(
                        shift_logits,
                        shift_labels,
                        process_group=shard_config.tensor_parallel_process_group,
                        vocab_size=self.lm_head.out_features,
                        dtype=self.model.dtype,
                    )
                else:
                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
                    loss = loss_fct(shift_logits, shift_labels)

            if not return_dict:
                output = (logits,) + outputs[1:]
                return (loss,) + output if loss is not None else output

            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        else:
            hidden_states = outputs.get("hidden_states")
            return {"hidden_states": hidden_states}


def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if sp_mode is not None:
            assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
            assert (sp_size is not None) and (
                sp_group is not None
            ), "Must specify sp_size and sp_group for sequence parallel"
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        bsz, q_len, _ = hidden_states.size()
        # sp: modify sp_len when sequence parallel mode is ring
        if sp_mode in ["split_gather", "ring"]:
            q_len *= sp_size

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # sp: all-to-all communication when introducing sequence parallel
        if sp_mode == "all_to_all":
            query_states = all_to_all_comm(query_states, sp_group)
            key_states = all_to_all_comm(key_states, sp_group)
            value_states = all_to_all_comm(value_states, sp_group)
            bsz, q_len, _ = query_states.size()

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )

            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if shard_config.enable_flash_attention:
            assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
            attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
        else:
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                    f" {attn_weights.size()}"
                )

            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                    )
                attn_weights = attn_weights + attention_mask

            # upcast attention to fp32
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights, value_states)

            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )

            attn_output = attn_output.transpose(1, 2).contiguous()
        # sp: all-to-all communication when introducing sequence parallel
        if sp_mode == "all_to_all":
            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
        else:
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value

    return forward


def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
    logger = logging.get_logger(__name__)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        past_seen_tokens = 0
        seq_len = inputs_embeds.shape[1]
        if use_cache:  # kept for BC (cache positions)
            if not isinstance(past_key_values, StaticCache):
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                past_seen_tokens = past_key_values.get_seq_length()
        if cache_position is None:
            if isinstance(past_key_values, StaticCache):
                raise ValueError("cache_position is a required argument when using StaticCache.")
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # in this case, attention_mask is a dict rather than a tensor
        if shard_config.enable_flash_attention:
            mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
            attention_mask = ColoAttention.prepare_attn_kwargs(
                mask_shape,
                inputs_embeds.dtype,
                inputs_embeds.device,
                q_padding_mask=attention_mask,
                is_causal=True,
            )
        else:
            attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

        if sp_mode in ["ring", "split_gather"]:
            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
        elif sp_mode == "all_to_all":
            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )

            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if sp_mode == "ring" or sp_mode == "split_gather":
            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
        elif sp_mode == "all_to_all":
            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
            )
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    return forward


def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
    from transformers import CohereForCausalLM

    def forward(
        self: CohereForCausalLM,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, CohereForCausalLM

        >>> model = CohereForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]

        logits = self.lm_head(hidden_states)
        logits = logits * self.logit_scale
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            new_vocab_size = logits.shape[-1]
            shift_logits = shift_logits.view(-1, new_vocab_size)
            loss = cross_entropy_1d(
                shift_logits,
                shift_labels,
                process_group=shard_config.tensor_parallel_process_group,
                vocab_size=self.lm_head.out_features,
                dtype=self.model.dtype,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    return forward
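get_lm_forward_with_dist_cross_entropy above swaps the usual full-vocabulary loss for cross_entropy_1d, which computes the same next-token loss while each tensor-parallel rank only holds a slice of the vocabulary. The sketch below shows the single-GPU equivalent of the label handling ("shift so that tokens < n predict n") that the parallel version reproduces; the tensor sizes are illustrative assumptions, not values from the PR:

```python
import torch
import torch.nn.functional as F

batch, seq, vocab = 2, 8, 32  # illustrative sizes only
logits = torch.randn(batch, seq, vocab)
labels = torch.randint(0, vocab, (batch, seq))

# Position t predicts token t+1, so drop the last logit and the first label.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab)
shift_labels = labels[..., 1:].contiguous().view(-1)

# cross_entropy_1d yields this same value, but with the logits split over the
# tensor-parallel process group along the vocabulary dimension.
loss = F.cross_entropy(shift_logits, shift_labels)
print(loss.item())
```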
@@ -192,6 +192,13 @@ _POLICY_LIST = {
     "transformers.models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification": PolicyLocation(
         file_name="qwen2", class_name="Qwen2ForSequenceClassificationPolicy"
     ),
+    # Command-R
+    "transformers.models.cohere.modeling_cohere.CohereModel": PolicyLocation(
+        file_name="command", class_name="CommandModelPolicy"
+    ),
+    "transformers.models.cohere.modeling_cohere.CohereForCausalLM": PolicyLocation(
+        file_name="command", class_name="CommandForCausalLMPolicy"
+    ),
 }
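The _POLICY_LIST entries map a fully qualified transformers class name to the shardformer policy module and class that should shard it. A rough sketch of how such a table can be resolved at runtime; resolve_policy below is a hypothetical helper for illustration, not the repository's actual auto_policy loader:

```python
import importlib
from dataclasses import dataclass


@dataclass
class PolicyLocation:
    file_name: str
    class_name: str


_POLICY_LIST = {
    "transformers.models.cohere.modeling_cohere.CohereForCausalLM": PolicyLocation(
        file_name="command", class_name="CommandForCausalLMPolicy"
    ),
}


def resolve_policy(model) -> object:
    """Hypothetical helper: look up and instantiate the policy for a model instance."""
    qualname = f"{model.__class__.__module__}.{model.__class__.__qualname__}"
    location = _POLICY_LIST[qualname]
    module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(module, location.class_name)()
```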
@@ -67,7 +67,7 @@ class BertPolicy(Policy):
         else:
             norm_cls = col_nn.LayerNorm

-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
         assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for Bert"
         if sp_mode == "ring":
             warnings.warn(
@@ -50,7 +50,7 @@ class BloomPolicy(Policy):
         else:
             norm_cls = col_nn.LayerNorm

-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
         assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for BLOOM"
         if sp_mode == "ring":
             warnings.warn(
@@ -57,7 +57,7 @@ class ChatGLMPolicy(Policy):
         else:
             norm_cls = col_nn.LayerNorm

-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
         assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for ChatGLM2"
         if sp_mode == "ring":
             warnings.warn(
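The one-line change repeated in the Bert, BLOOM, ChatGLM (and, further below, GPT2 and Llama) policies replaces the explicit enable_sequence_parallelism check with a truthiness fallback. The two forms agree as long as sequence_parallelism_mode is left unset whenever sequence parallelism is disabled; a small illustration with a stand-in config object (SimpleNamespace is only a placeholder for ShardConfig):

```python
from types import SimpleNamespace

# Stand-in for ShardConfig; only the two fields used here are modelled.
cfg = SimpleNamespace(enable_sequence_parallelism=False, sequence_parallelism_mode=None)

old_style = cfg.sequence_parallelism_mode if cfg.enable_sequence_parallelism else None
new_style = cfg.sequence_parallelism_mode or None
assert old_style is None and new_style is None

# The two forms only diverge if a mode is set while the feature flag is off.
cfg = SimpleNamespace(enable_sequence_parallelism=False, sequence_parallelism_mode="ring")
print(cfg.sequence_parallelism_mode if cfg.enable_sequence_parallelism else None)  # None
print(cfg.sequence_parallelism_mode or None)  # "ring"
```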
@@ -0,0 +1,369 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Union

import torch.nn as nn
from torch import Tensor
from torch.nn import Module

from colossalai.shardformer.layer import (
    FusedLayerNorm,
    LayerNorm,
    Linear1D_Col,
    Linear1D_Row,
    PaddingEmbedding,
    PaddingLMHead,
    VocabParallelEmbedding1D,
    VocabParallelLMHead1D,
)

from ..modeling.command import (
    CommandPipelineForwards,
    get_command_flash_attention_forward,
    get_command_flash_attention_model_forward,
    get_lm_forward_with_dist_cross_entropy,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["CommandPolicy", "CommandForCausalLMPolicy"]


class CommandPolicy(Policy):
    def config_sanity_check(self):
        pass

    def preprocess(self):
        self.tie_weight = self.tie_weight_check()
        self.origin_attn_implement = self.model.config._attn_implementation
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        from transformers.models.cohere.modeling_cohere import (
            CohereAttention,
            CohereDecoderLayer,
            CohereFlashAttention2,
            CohereModel,
            CohereSdpaAttention,
        )

        ATTN_IMPLEMENTATION = {
            "eager": CohereAttention,
            "flash_attention_2": CohereFlashAttention2,
            "sdpa": CohereSdpaAttention,
        }
        policy = {}

        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
        embedding_cls = None
        if self.shard_config.enable_tensor_parallelism:
            embedding_cls = VocabParallelEmbedding1D
        else:
            if self.tie_weight:
                embedding_cls = PaddingEmbedding

        if self.shard_config.enable_fused_normalization:
            norm_cls = FusedLayerNorm
        else:
            norm_cls = LayerNorm

        if self.pipeline_stage_manager is not None:
            self.shard_config.enable_sequence_parallelism = False
            self.shard_config.enable_sequence_overlap = False
            self.shard_config.sequence_parallelism_mode = None
            warnings.warn(
                f"For Command, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
            )
        sp_mode = self.shard_config.sequence_parallelism_mode or None
        sp_size = self.shard_config.sequence_parallel_size or None
        sp_group = self.shard_config.sequence_parallel_process_group or None
        sp_partial_derived = sp_mode in ["split_gather", "ring"]

        if sp_mode == "all_to_all":
            decoder_attribute_replacement = {
                "num_heads": self.model.config.num_attention_heads // sp_size,
            }
            if getattr(self.model.config, "num_key_value_heads", False):
                decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size

            policy[attn_cls] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
            )
        if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
            self.append_or_create_method_replacement(
                description={
                    "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                },
                policy=policy,
                target_key=attn_cls,
            )
            if self.pipeline_stage_manager is None:
                self.append_or_create_method_replacement(
                    description={
                        "forward": get_command_flash_attention_model_forward(
                            self.shard_config,
                            sp_mode=sp_mode,
                            sp_size=sp_size,
                            sp_group=sp_group,
                        ),
                    },
                    policy=policy,
                    target_key=CohereModel,
                )

        if self.shard_config.enable_tensor_parallelism:
            assert (
                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
            ), f"The number of attention heads must be divisible by tensor parallel size."
            if hasattr(self.model.config, "num_key_value_heads"):
                assert (
                    self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
                    and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
                ), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
            decoder_attribute_replacement = {
                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
            }
            if getattr(self.model.config, "num_key_value_heads", False):
                decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
                    self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
                )

            policy[CohereDecoderLayer] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=Linear1D_Col,
                        kwargs=dict(seq_parallel_mode=sp_mode),
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=Linear1D_Col,
                        kwargs=dict(seq_parallel_mode=sp_mode),
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=Linear1D_Col,
                        kwargs=dict(seq_parallel_mode=sp_mode),
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.o_proj",
                        target_module=Linear1D_Row,
                        kwargs=dict(seq_parallel_mode=sp_mode),
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.gate_proj",
                        target_module=Linear1D_Col,
                        kwargs=dict(seq_parallel_mode=sp_mode),
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.up_proj",
                        target_module=Linear1D_Col,
                        kwargs=dict(seq_parallel_mode=sp_mode),
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.down_proj",
                        target_module=Linear1D_Row,
                        kwargs=dict(seq_parallel_mode=sp_mode),
                    ),
                ],
            )

        if embedding_cls is not None:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="embed_tokens",
                    target_module=embedding_cls,
                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
                ),
                policy=policy,
                target_key=CohereModel,
            )

        # optimization configuration
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="input_layernorm",
                    target_module=norm_cls,
                    kwargs={"sp_partial_derived": sp_partial_derived},
                ),
            ],
            policy=policy,
            target_key=CohereDecoderLayer,
        )

        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(
                suffix="norm",
                target_module=norm_cls,
                kwargs={"sp_partial_derived": sp_partial_derived},
            ),
            policy=policy,
            target_key=CohereModel,
        )

        return policy

    def postprocess(self):
        return self.model

    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
        """If under pipeline parallel setting, replacing the original forward method of huggingface
        to customized forward method, and add this changing to policy."""
        if self.pipeline_stage_manager is None:
            return

        stage_manager = self.pipeline_stage_manager
        if self.model.__class__.__name__ == "CohereModel":
            module = self.model
        else:
            module = self.model.model

        if stage_manager.is_interleave:
            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
            stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
            method_replacement = {
                "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
            }

        else:
            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
            stage_index = stage_manager.get_stage_index(layers_per_stage)
            method_replacement = {
                "forward": partial(
                    new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
                )
            }

        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        assert self.pipeline_stage_manager is not None

        if self.model.__class__.__name__ == "CohereModel":
            module = self.model
        else:
            module = self.model.model
        stage_manager = self.pipeline_stage_manager

        held_layers = []
        if stage_manager.is_interleave:
            assert stage_manager.num_model_chunks is not None
            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
            stage_indices = stage_manager.get_stage_index(layers_per_stage)
            if stage_manager.is_first_stage(ignore_chunk=True):
                held_layers.append(module.embed_tokens)
            for start_idx, end_idx in stage_indices:
                held_layers.extend(module.layers[start_idx:end_idx])
            if stage_manager.is_last_stage(ignore_chunk=True):
                held_layers.append(module.norm)

        else:
            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
            if stage_manager.is_first_stage():
                held_layers.append(module.embed_tokens)
            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
            held_layers.extend(module.layers[start_idx:end_idx])
            if stage_manager.is_last_stage():
                held_layers.append(module.norm)

        return held_layers


class CommandModelPolicy(CommandPolicy):
    def module_policy(self):
        policy = super().module_policy()
        from transformers.models.cohere.modeling_cohere import CohereModel

        if self.pipeline_stage_manager:
            # set None as default
            self.set_pipeline_forward(
                model_cls=CohereModel, new_forward=CommandPipelineForwards.command_model_forward, policy=policy
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        held_layers = super().get_held_layers()
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in command model"""
        return []


class CommandForCausalLMPolicy(CommandPolicy):
    def module_policy(self):
        from transformers import CohereForCausalLM

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm
            new_item = {
                CohereForCausalLM: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head",
                            target_module=VocabParallelLMHead1D,
                            kwargs={
                                "gather_output": not self.shard_config.parallel_output,
                                "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
                            },
                        )
                    ],
                )
            }
            if self.shard_config.parallel_output:
                new_item[CohereForCausalLM].method_replacement = {
                    "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
                }
        else:
            new_item = {
                CohereForCausalLM: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head",
                            target_module=PaddingLMHead,
                            kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
                        )
                    ],
                )
            }
        policy.update(new_item)

        if self.pipeline_stage_manager:
            # set None as default
            self.set_pipeline_forward(
                model_cls=CohereForCausalLM,
                new_forward=CommandPipelineForwards.command_for_causal_lm_forward,
                policy=policy,
            )

        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.lm_head)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        command_model = self.model.model
        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
            if (
                id(command_model.embed_tokens.weight) == id(self.model.lm_head.weight)
                and self.pipeline_stage_manager.num_stages > 1
            ):
                # tie weights
                return [
                    {
                        0: command_model.embed_tokens.weight,
                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
                    }
                ]
        return []
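Once the Cohere classes are registered in _POLICY_LIST, the CommandPolicy replacements above are applied automatically when a Command-R model is boosted with a parallel plugin. A rough end-to-end sketch, assuming a HybridParallelPlugin with tensor parallelism and fused layer norm; the launch/boost signatures vary across ColossalAI versions and the model sizes are placeholders, not values from this PR:

```python
# Rough sketch only: exact launch/boost signatures differ between ColossalAI releases.
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import CohereConfig, CohereForCausalLM

colossalai.launch_from_torch({})  # assumes torchrun provides rank/world-size env vars

# Tiny placeholder config so the sketch stays cheap to build.
config = CohereConfig(num_hidden_layers=2, hidden_size=128, intermediate_size=256,
                      num_attention_heads=4, num_key_value_heads=4)
model = CohereForCausalLM(config)

# tp_size=2 routes CohereDecoderLayer and lm_head through the CommandPolicy replacements.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, enable_fused_normalization=True)
booster = Booster(plugin=plugin)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)
```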
@@ -65,7 +65,7 @@ class GPT2Policy(Policy):
         else:
             norm_cls = col_nn.LayerNorm

-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
         assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for GPT2"
         if sp_mode == "ring":
             warnings.warn(
@ -73,11 +73,9 @@ class LlamaPolicy(Policy):
             warnings.warn(
                 f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
             )
-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
-        sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None
-        sp_group = (
-            self.shard_config.sequence_parallel_process_group if self.shard_config.enable_sequence_parallelism else None
-        )
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
+        sp_size = self.shard_config.sequence_parallel_size or None
+        sp_group = self.shard_config.sequence_parallel_process_group or None
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
 
         if sp_mode == "all_to_all":
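The `or None` rewrites in the two hunks above rely on the shard config leaving these sequence-parallel fields unset (None) whenever sequence parallelism is disabled, which makes the explicit `if ... else None` guard redundant. A small illustration under that assumption; the `_Cfg` stub below is a hypothetical stand-in, not ColossalAI's ShardConfig:

class _Cfg:
    # Stand-in for the relevant shard-config fields when sequence parallelism is off.
    enable_sequence_parallelism = False
    sequence_parallelism_mode = None

cfg = _Cfg()
old_style = cfg.sequence_parallelism_mode if cfg.enable_sequence_parallelism else None
new_style = cfg.sequence_parallelism_mode or None
assert old_style is None and new_style is None  # both forms agree when SP is disabled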
@ -22,3 +22,9 @@ try:
     from .qwen2 import *
 except ImportError:
     print("This version of transformers doesn't support qwen2.")
+
+
+try:
+    from .command import *
+except ImportError:
+    print("This version of transformers doesn't support Command-R.")
@ -0,0 +1,79 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

try:
    from transformers import CohereConfig

    HAS_COMMAND = True
except ImportError:
    HAS_COMMAND = False

if HAS_COMMAND:
    # ===============================
    # Register Command-R
    # ===============================

    def data_gen():
        input_ids = torch.Tensor(
            [
                [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
                [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
            ]
        ).long()

        attention_mask = torch.Tensor(
            [
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            ]
        ).long()

        return dict(input_ids=input_ids, attention_mask=attention_mask)

    # label is needed for causal lm
    def data_gen_for_casual_lm():
        data = data_gen()
        labels = data["input_ids"].clone()
        data["labels"] = labels
        return data

    # transform the output to a dict
    output_transform_fn = lambda x: x

    # function to get the loss
    loss_fn = lambda output: output["last_hidden_state"].mean()
    loss_fn_for_casual_lm = lambda output: output["loss"]
    loss_fn_for_seq_classification = lambda output: output["logits"].mean()

    config = CohereConfig(
        num_hidden_layers=8,
        hidden_size=32,
        intermediate_size=64,
        num_attention_heads=4,
        max_position_embeddings=128,
    )

    if hasattr(config, "pad_token_id"):
        config.pad_token_id = config.eos_token_id

    # register the following models
    # transformers.CohereModel,
    # transformers.CohereForCausalLM,
    model_zoo.register(
        name="transformers_command",
        model_fn=lambda: transformers.CohereModel(config),
        data_gen_fn=data_gen,
        output_transform_fn=output_transform_fn,
        loss_fn=loss_fn,
        model_attribute=ModelAttribute(has_control_flow=True),
    )
    model_zoo.register(
        name="transformers_command_for_casual_lm",
        model_fn=lambda: transformers.CohereForCausalLM(config),
        data_gen_fn=data_gen_for_casual_lm,
        output_transform_fn=output_transform_fn,
        loss_fn=loss_fn_for_casual_lm,
        model_attribute=ModelAttribute(has_control_flow=True),
    )
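A rough smoke-test sketch for the entries registered above, unpacking the registry tuples the same way the shardformer test below does; the single-keyword get_sub_registry call and running the tiny model on CPU are assumptions for illustration, not part of the PR:

from tests.kit.model_zoo import model_zoo

sub_zoo = model_zoo.get_sub_registry("transformers_command_for_casual_lm")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_zoo.items():
    model = model_fn()                                 # tiny CohereForCausalLM built from the config above
    output = output_transform_fn(model(**data_gen_fn()))
    loss = loss_fn(output)                             # picks output["loss"], populated because labels are provided
    loss.backward()
    print(name, float(loss))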
@ -0,0 +1,322 @@
import os

import pytest
import torch
import torch.distributed as dist
from torch.testing import assert_close

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import PipelineGradientCheckpointConfig
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
    build_model_from_hybrid_plugin,
    check_all_grad_tensors,
    check_loss,
    check_output_hidden_state,
    check_weight,
    get_grad_tensors_for_check,
    run_forward_backward_with_hybrid_plugin,
    unwrap_model,
)

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
    enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
        model_fn, loss_fn, test_config
    )
    if enable_gradient_checkpointing:
        # org_model.gradient_checkpointing_enable()
        sharded_model.unwrap().gradient_checkpointing_enable()

    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
    )

    stage_manager = booster.plugin.stage_manager
    tp_group = booster.plugin.tp_group

    # unwrap model
    command_model = unwrap_model(org_model, "CohereModel", "model")
    shard_command_model = unwrap_model(sharded_model, "CohereModel", "model")

    row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
    col_layer_for_check = ["layers[0].self_attn.o_proj"]
    # Here we check the grad of layernorm because an all-reduce operation should be performed during sequence parallelism
    norm_layer_for_check = ["layers[0].input_layernorm", "layers[1].input_layernorm"]

    # During pipeline parallelism, we cannot get the grad of the norm layer on the first stage, so we only check it when pp is not enabled
    if stage_manager is None:
        norm_layer_for_check.append("norm")

    # Check the grad when using ZeRO-1 and ZeRO-2
    if (
        booster.plugin.zero_stage in [1, 2]
        and booster.plugin.shard_config.enable_sequence_parallelism
        and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
    ):
        for p1, p2 in zip(command_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
            working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
            grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
            grad_index = (
                0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank
            )
            grad = grads[grad_index]
            sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
            assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)

    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
    grads_to_check = {}
    if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-6, 1e-4
        else:
            atol, rtol = 5e-3, 5e-3
        row_layer_grads = get_grad_tensors_for_check(
            command_model,
            shard_command_model,
            row_layer_for_check,
            tp_group,
            atol=atol,
            rtol=rtol,
            dim=0,
            verbose=False,
        )
        col_layer_grads = get_grad_tensors_for_check(
            command_model,
            shard_command_model,
            col_layer_for_check,
            tp_group,
            atol=atol,
            rtol=rtol,
            dim=1,
            verbose=False,
        )
        norm_layer_grads = get_grad_tensors_for_check(
            command_model,
            shard_command_model,
            norm_layer_for_check,
            tp_group,
            atol=atol,
            rtol=rtol,
            dim=1,
            verbose=False,
        )
        grads_to_check.update(col_layer_grads)
        grads_to_check.update(row_layer_grads)
        grads_to_check.update(norm_layer_grads)

    # optimizer executes step
    org_optimizer.step()
    sharded_optimizer.step()

    # check last hidden state & loss
    if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-5, 1e-3
        else:
            atol, rtol = 5e-3, 5e-3

        if org_model.__class__.__name__ == "CohereModel":
            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

    # check weights
    if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-4, 1e-3
        else:
            atol, rtol = 5e-3, 5e-3
        check_weight(
            command_model,
            shard_command_model,
            col_layer_for_check,
            tp_group,
            atol=atol,
            rtol=rtol,
            dim=1,
            verbose=False,
        )

    # check grads
    check_all_grad_tensors(grads_to_check)

    torch.cuda.empty_cache()


@parameterize(
    "test_config",
    [
        {
            "tp_size": 2,
            "pp_size": 1,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "ring",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 2,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 4,
            "pp_size": 1,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "split_gather",
            "enable_flash_attention": False,
            "use_lazy_init": True,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 1,
            "pp_size": 1,
            "sp_size": 2,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "use_lazy_init": True,
            "zero_stage": 2,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 2,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "precision": "fp16",
            "initial_scale": 1,
            "enable_gradient_checkpointing": True,
            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
        },
        {
            "tp_size": 1,
            "pp_size": 2,
            "num_microbatches": 4,
            "use_lazy_init": False,
            "precision": "fp32",
            "enable_gradient_checkpointing": True,
            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
        },
        {
            "tp_size": 2,
            "pp_size": 1,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "zero_stage": 2,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 1,
            "pp_size": 2,
            "num_microbatches": 2,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
    ],
)
def run_command_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()


@parameterize(
    "test_config",
    [
        {
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "use_lazy_init": False,
            "precision": "fp32",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "use_lazy_init": False,
            "precision": "fp16",
            "zero_stage": 1,
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "pp_style": "interleaved",
            "num_model_chunks": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "precision": "fp16",
            "zero_stage": 1,
            "initial_scale": 1,
            "enable_gradient_checkpointing": True,
            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
                num_ckpt_layers_per_stage=[0, 1, 2, 2],
            ),
        },
    ],
)
def run_command_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()


def check_command(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_command_test()


def check_command_3d(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_command_3d_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_command():
    spawn(check_command, 4)


@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_command_3d():
    spawn(check_command_3d, 8)


if __name__ == "__main__":
    test_command()
    test_command_3d()
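For reference, the precision-dependent tolerances in the test above boil down to torch.testing.assert_close with looser bounds under fp16; a standalone sketch with made-up numbers (not part of the PR):

import torch
from torch.testing import assert_close

a = torch.full((4,), 0.5)
b = a + 1e-4                                # small drift, e.g. accumulated by sharded all-reduces
assert_close(a, b, atol=5e-3, rtol=5e-3)    # passes under the fp16-style tolerances used above
# assert_close(a, b, atol=1e-6, rtol=1e-4)  # the stricter fp32 bounds would raise on this drift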