mirror of https://github.com/hpcaitech/ColossalAI
YeAnbang
5 months ago
64 changed files with 2706 additions and 1071 deletions
@@ -0,0 +1,170 @@
from abc import ABC, abstractmethod |
||||
from dataclasses import dataclass |
||||
|
||||
import torch |
||||
|
||||
from colossalai.inference.config import ModelShardInferenceConfig |
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader |
||||
from colossalai.kernel.triton import context_attention_unpadded, flash_decoding_attention |
||||
|
||||
|
||||
@dataclass |
||||
class AttentionMetaData: |
||||
query_states: torch.Tensor |
||||
key_states: torch.Tensor |
||||
value_states: torch.Tensor |
||||
k_cache: torch.Tensor |
||||
v_cache: torch.Tensor |
||||
block_tables: torch.Tensor |
||||
block_size: int |
||||
kv_seq_len: int = None |
||||
sequence_lengths: torch.Tensor = None |
||||
cu_seqlens: torch.Tensor = None |
||||
sm_scale: int = None |
||||
alibi_slopes: torch.Tensor = None |
||||
output_tensor: torch.Tensor = None |
||||
use_spec_dec: bool = False |
||||
use_alibi_attn: bool = False |
||||
|
||||
|
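For orientation, here is a minimal sketch of filling in AttentionMetaData by hand; the shapes and block-table layout are illustrative assumptions only (not the engine's real KV-cache layout), and in practice the inference engine populates these fields from its batch and cache managers. It assumes ColossalAI is installed so the module path below resolves.

import torch

from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData

# Illustrative shapes only -- not the engine's real KV-cache layout.
num_tokens, num_heads, head_dim, num_blocks, block_size = 16, 4, 64, 4, 16

meta = AttentionMetaData(
    query_states=torch.randn(num_tokens, num_heads, head_dim, dtype=torch.float16),
    key_states=torch.randn(num_tokens, num_heads, head_dim, dtype=torch.float16),
    value_states=torch.randn(num_tokens, num_heads, head_dim, dtype=torch.float16),
    k_cache=torch.zeros(num_blocks, num_heads, block_size, head_dim, dtype=torch.float16),
    v_cache=torch.zeros(num_blocks, num_heads, block_size, head_dim, dtype=torch.float16),
    block_tables=torch.arange(num_blocks, dtype=torch.int32).view(1, -1),
    block_size=block_size,
    kv_seq_len=num_tokens,
    sequence_lengths=torch.tensor([num_tokens], dtype=torch.int32),
    sm_scale=1.0 / head_dim**0.5,
)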
||||
class AttentionBackend(ABC): |
||||
@abstractmethod |
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs): |
||||
raise NotImplementedError |
||||
|
||||
@abstractmethod |
||||
def decode(self, attn_metadatas: AttentionMetaData, **kwargs): |
||||
raise NotImplementedError |
||||
|
||||
|
||||
class CudaAttentionBackend(AttentionBackend): |
||||
""" |
||||
Attention backend when use_cuda_kernel is True but flash-attn not found. If flash-attn is not found, |
||||
it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding. |
||||
""" |
||||
|
||||
def __init__(self, use_flash_attn: bool = False): |
||||
super().__init__() |
||||
self.inference_ops = InferenceOpsLoader().load() |
||||
self.use_flash_attn = use_flash_attn |
||||
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs): |
||||
if self.use_flash_attn: |
||||
token_nums = kwargs.get("token_nums", -1) |
||||
|
||||
from flash_attn import flash_attn_varlen_func |
||||
|
||||
attn_output = flash_attn_varlen_func( |
||||
attn_metadata.query_states, |
||||
attn_metadata.key_states, |
||||
attn_metadata.value_states, |
||||
cu_seqlens_q=attn_metadata.cu_seqlens, |
||||
cu_seqlens_k=attn_metadata.cu_seqlens, |
||||
max_seqlen_q=attn_metadata.kv_seq_len, |
||||
max_seqlen_k=attn_metadata.kv_seq_len, |
||||
dropout_p=0.0, |
||||
softmax_scale=attn_metadata.sm_scale, |
||||
causal=True, |
||||
alibi_slopes=attn_metadata.alibi_slopes, |
||||
) |
||||
attn_output = attn_output.view(token_nums, -1) |
||||
else: |
||||
attn_output = context_attention_unpadded( |
||||
q=attn_metadata.query_states, |
||||
k=attn_metadata.key_states, |
||||
v=attn_metadata.value_states, |
||||
k_cache=attn_metadata.k_cache, |
||||
v_cache=attn_metadata.v_cache, |
||||
context_lengths=attn_metadata.sequence_lengths, |
||||
block_tables=attn_metadata.block_tables, |
||||
block_size=attn_metadata.block_size, |
||||
output=attn_metadata.output_tensor, |
||||
alibi_slopes=attn_metadata.alibi_slopes, |
||||
max_seq_len=attn_metadata.kv_seq_len, |
||||
sm_scale=attn_metadata.sm_scale, |
||||
use_new_kcache_layout=True, # use new k-cache layout |
||||
) |
||||
return attn_output |
||||
|
||||
def decode(self, attn_metadata: AttentionMetaData, **kwargs): |
||||
fd_inter_tensor = kwargs.get("fd_inter_tensor", None) |
||||
output_tensor = attn_metadata.output_tensor |
||||
self.inference_ops.flash_decoding_attention( |
||||
output_tensor, |
||||
attn_metadata.query_states, |
||||
attn_metadata.k_cache, |
||||
attn_metadata.v_cache, |
||||
attn_metadata.sequence_lengths, |
||||
attn_metadata.block_tables, |
||||
attn_metadata.block_size, |
||||
attn_metadata.kv_seq_len, |
||||
fd_inter_tensor.mid_output, |
||||
fd_inter_tensor.exp_sums, |
||||
fd_inter_tensor.max_logits, |
||||
attn_metadata.alibi_slopes, |
||||
attn_metadata.sm_scale, |
||||
) |
||||
return output_tensor |
||||
|
||||
|
||||
class TritonAttentionBackend(AttentionBackend): |
||||
""" |
||||
Attention backend when use_cuda_kernel is False. It uses pure Triton ops for prefilling and decoding. |
||||
""" |
||||
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs): |
||||
return context_attention_unpadded( |
||||
q=attn_metadata.query_states, |
||||
k=attn_metadata.key_states, |
||||
v=attn_metadata.value_states, |
||||
k_cache=attn_metadata.k_cache, |
||||
v_cache=attn_metadata.v_cache, |
||||
context_lengths=attn_metadata.sequence_lengths, |
||||
block_tables=attn_metadata.block_tables, |
||||
block_size=attn_metadata.block_size, |
||||
output=attn_metadata.output_tensor, |
||||
alibi_slopes=attn_metadata.alibi_slopes, |
||||
max_seq_len=attn_metadata.kv_seq_len, |
||||
sm_scale=attn_metadata.sm_scale, |
||||
) |
||||
|
||||
def decode(self, attn_metadata: AttentionMetaData, **kwargs): |
||||
fd_inter_tensor = kwargs.get("fd_inter_tensor", None) |
||||
return flash_decoding_attention( |
||||
q=attn_metadata.query_states, |
||||
k_cache=attn_metadata.k_cache, |
||||
v_cache=attn_metadata.v_cache, |
||||
kv_seq_len=attn_metadata.sequence_lengths, |
||||
block_tables=attn_metadata.block_tables, |
||||
block_size=attn_metadata.block_size, |
||||
max_seq_len_in_batch=attn_metadata.kv_seq_len, |
||||
output=attn_metadata.output_tensor, |
||||
mid_output=fd_inter_tensor.mid_output, |
||||
mid_output_lse=fd_inter_tensor.mid_output_lse, |
||||
alibi_slopes=attn_metadata.alibi_slopes, |
||||
sm_scale=attn_metadata.sm_scale, |
||||
kv_group_num=kwargs.get("num_key_value_groups", 1), |
||||
q_len=kwargs.get("q_len", 1), |
||||
) |
||||
|
||||
|
||||
def get_attention_backend( |
||||
model_shard_infer_config: ModelShardInferenceConfig, |
||||
) -> AttentionBackend: |
||||
""" |
||||
Get the attention backend based on the inference configurations. The modeling will use CUDA-kernel-based backend |
||||
for attention module calculation only when: |
||||
1. using CUDA kernel (use_cuda_kernel=True) |
||||
2. can use flash attention (flash-attn installed and dtype is fp16 or bf16) |
||||
3. not using speculative decoding (currently cuda kernel not support speculative decoding) |
||||
Otherwise, use Triton attention backend. If found flash-attn not installed while `use_cuda_kernel` is True, |
||||
the Triton backend will use a new k cache layout for Triton kernels. |
||||
""" |
||||
# Currently only triton kernels support speculative decoding |
||||
if model_shard_infer_config.use_spec_dec: |
||||
return TritonAttentionBackend() |
||||
|
||||
if model_shard_infer_config.use_cuda_kernel: |
||||
return CudaAttentionBackend(model_shard_infer_config.use_flash_attn) |
||||
|
||||
return TritonAttentionBackend() |
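A usage sketch (illustrative, not part of this diff) of how a caller might select a backend and dispatch prefill vs. decode. It assumes ColossalAI and its Triton kernels are installed; the SimpleNamespace stand-in exposes only the flags that get_attention_backend() reads, whereas the real ModelShardInferenceConfig carries more fields.

from types import SimpleNamespace

from colossalai.inference.modeling.backends.attention_backend import get_attention_backend

# Stand-in config (assumption): only the flags read by get_attention_backend().
cfg = SimpleNamespace(use_spec_dec=False, use_cuda_kernel=False, use_flash_attn=False)
backend = get_attention_backend(cfg)  # TritonAttentionBackend for this config

# During prefill the engine fills an AttentionMetaData and calls
#   backend.prefill(meta, token_nums=...)
# and during decode
#   backend.decode(meta, fd_inter_tensor=..., num_key_value_groups=..., q_len=1)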
@@ -0,0 +1,146 @@
from abc import ABC, abstractmethod |
||||
|
||||
from colossalai.inference.config import ModelShardInferenceConfig |
||||
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData |
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader |
||||
from colossalai.kernel.triton import copy_k_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding |
||||
|
||||
|
||||
class PreAttentionBackend(ABC): |
||||
@abstractmethod |
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs): |
||||
raise NotImplementedError |
||||
|
||||
@abstractmethod |
||||
def decode(self, attn_metadata: AttentionMetaData, **kwargs): |
||||
raise NotImplementedError |
||||
|
||||
|
||||
class CudaPreAttentionBackend(PreAttentionBackend): |
||||
""" |
||||
CudaPreAttentionBackend handles KV cache initialization and positional encoding for CudaAttentionBackend. |
||||
""" |
||||
|
||||
def __init__(self, use_flash_attn: bool): |
||||
super().__init__() |
||||
self.inference_ops = InferenceOpsLoader().load() |
||||
self.use_flash_attn = use_flash_attn |
||||
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs): |
||||
if self.use_flash_attn: |
||||
if not attn_metadata.use_alibi_attn: |
||||
self.inference_ops.rotary_embedding( |
||||
attn_metadata.query_states, |
||||
attn_metadata.key_states, |
||||
kwargs.get("cos", None), |
||||
kwargs.get("sin", None), |
||||
kwargs.get("high_precision", False), |
||||
) |
||||
self.inference_ops.context_kv_cache_memcpy( |
||||
attn_metadata.key_states, |
||||
attn_metadata.value_states, |
||||
attn_metadata.k_cache, |
||||
attn_metadata.v_cache, |
||||
attn_metadata.sequence_lengths, |
||||
attn_metadata.cu_seqlens, |
||||
attn_metadata.block_tables, |
||||
attn_metadata.kv_seq_len, |
||||
) |
||||
elif not attn_metadata.use_alibi_attn: |
||||
rotary_embedding( |
||||
attn_metadata.query_states, |
||||
attn_metadata.key_states, |
||||
kwargs.get("cos", None), |
||||
kwargs.get("sin", None), |
||||
) |
||||
|
||||
def decode(self, attn_metadata: AttentionMetaData, **kwargs): |
||||
if not attn_metadata.use_alibi_attn: |
||||
self.inference_ops.rotary_embedding_and_cache_copy( |
||||
attn_metadata.query_states, |
||||
attn_metadata.key_states, |
||||
attn_metadata.value_states, |
||||
kwargs.get("cos", None), |
||||
kwargs.get("sin", None), |
||||
attn_metadata.k_cache, |
||||
attn_metadata.v_cache, |
||||
attn_metadata.sequence_lengths, |
||||
attn_metadata.block_tables, |
||||
kwargs.get("high_precision", None), |
||||
) |
||||
else: |
||||
self.inference_ops.decode_kv_cache_memcpy( |
||||
attn_metadata.key_states, |
||||
attn_metadata.value_states, |
||||
attn_metadata.k_cache, |
||||
attn_metadata.v_cache, |
||||
attn_metadata.sequence_lengths, |
||||
attn_metadata.block_tables, |
||||
) |
||||
|
||||
|
||||
class TritonPreAttentionBackend(PreAttentionBackend): |
||||
""" |
||||
TritonPreAttentionBackend handles KV cache initialization and positional encoding for TritonAttentionBackend. |
||||
""" |
||||
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs): |
||||
if not attn_metadata.use_alibi_attn: |
||||
rotary_embedding( |
||||
attn_metadata.query_states, |
||||
attn_metadata.key_states, |
||||
kwargs.get("cos", None), |
||||
kwargs.get("sin", None), |
||||
) |
||||
|
||||
def decode(self, attn_metadata: AttentionMetaData, **kwargs): |
||||
if not attn_metadata.use_spec_dec and not attn_metadata.use_alibi_attn: |
||||
decoding_fused_rotary_embedding( |
||||
attn_metadata.query_states, |
||||
attn_metadata.key_states, |
||||
attn_metadata.value_states, |
||||
kwargs.get("cos", None), |
||||
kwargs.get("sin", None), |
||||
attn_metadata.k_cache, |
||||
attn_metadata.v_cache, |
||||
attn_metadata.block_tables, |
||||
attn_metadata.sequence_lengths, |
||||
) |
||||
else:  # speculative decoding or alibi path: apply RoPE where applicable, then copy K/V into the blocked cache
||||
if not attn_metadata.use_alibi_attn: |
||||
rotary_embedding( |
||||
attn_metadata.query_states, |
||||
attn_metadata.key_states, |
||||
kwargs.get("cos", None), |
||||
kwargs.get("sin", None), |
||||
) |
||||
copy_k_to_blocked_cache( |
||||
attn_metadata.key_states, |
||||
attn_metadata.k_cache, |
||||
kv_lengths=attn_metadata.sequence_lengths, |
||||
block_tables=attn_metadata.block_tables, |
||||
n=kwargs.get("q_len", 1), |
||||
) |
||||
copy_k_to_blocked_cache( |
||||
attn_metadata.value_states, |
||||
attn_metadata.v_cache, |
||||
kv_lengths=attn_metadata.sequence_lengths, |
||||
block_tables=attn_metadata.block_tables, |
||||
n=kwargs.get("q_len", 1), |
||||
) |
||||
|
||||
|
||||
def get_pre_attention_backend( |
||||
model_shard_infer_config: ModelShardInferenceConfig, |
||||
) -> PreAttentionBackend: |
||||
""" |
||||
Get the backend for pre-attention computations, including potisional encoding like |
||||
RoPE and KV cache initialization. It adopt the same selection logic as attention_backend/get_attention_backend. |
||||
""" |
||||
if model_shard_infer_config.use_spec_dec: |
||||
return TritonPreAttentionBackend() |
||||
|
||||
if model_shard_infer_config.use_cuda_kernel: |
||||
return CudaPreAttentionBackend(model_shard_infer_config.use_flash_attn) |
||||
|
||||
return TritonPreAttentionBackend() |
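The pre-attention backends apply rotary position embeddings with fused Triton/CUDA kernels and copy K/V into the blocked cache. As a plain-PyTorch reference for what the rotary step computes, a minimal sketch under the usual half-rotation formulation (the fused kernels work in place on unpadded token layouts and may use a different memory layout; the cache copy is omitted here):

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dim in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # q, k: [tokens, heads, head_dim]; cos, sin: [tokens, head_dim]
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over heads
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


t, h, d = 8, 4, 64
inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))
freqs = torch.outer(torch.arange(t).float(), inv_freq)  # [t, d/2]
emb = torch.cat((freqs, freqs), dim=-1)                 # [t, d]
q_rot, k_rot = apply_rope(torch.randn(t, h, d), torch.randn(t, h, d), emb.cos(), emb.sin())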
@@ -0,0 +1,692 @@
import math |
||||
import warnings |
||||
from typing import List, Optional, Tuple, Union |
||||
|
||||
import torch |
||||
import torch.utils.checkpoint |
||||
from torch import nn |
||||
from torch.nn import CrossEntropyLoss |
||||
from transformers.cache_utils import Cache, DynamicCache |
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
||||
from transformers.models.cohere.modeling_cohere import ( |
||||
CohereForCausalLM, |
||||
CohereModel, |
||||
StaticCache, |
||||
apply_rotary_pos_emb, |
||||
repeat_kv, |
||||
) |
||||
from transformers.utils import logging |
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager |
||||
from colossalai.shardformer.layer._operation import ( |
||||
all_to_all_comm, |
||||
gather_forward_split_backward, |
||||
split_forward_gather_backward, |
||||
) |
||||
from colossalai.shardformer.shard import ShardConfig |
||||
|
||||
from ..layer import ColoAttention, cross_entropy_1d |
||||
|
||||
|
||||
class CommandPipelineForwards: |
||||
""" |
||||
This class serves as a micro library for forward function substitution of Command models |
||||
under pipeline setting. |
||||
""" |
||||
|
||||
@staticmethod |
||||
def command_model_forward( |
||||
self: CohereModel, |
||||
input_ids: torch.LongTensor = None, |
||||
attention_mask: Optional[torch.Tensor] = None, |
||||
position_ids: Optional[torch.LongTensor] = None, |
||||
past_key_values: Optional[List[torch.FloatTensor]] = None, |
||||
inputs_embeds: Optional[torch.FloatTensor] = None, |
||||
use_cache: Optional[bool] = None, |
||||
output_attentions: Optional[bool] = None, |
||||
output_hidden_states: Optional[bool] = None, |
||||
return_dict: Optional[bool] = None, |
||||
cache_position: Optional[torch.LongTensor] = None, |
||||
stage_manager: Optional[PipelineStageManager] = None, |
||||
hidden_states: Optional[torch.FloatTensor] = None, |
||||
stage_index: Optional[List[int]] = None, |
||||
shard_config: ShardConfig = None, |
||||
): |
||||
logger = logging.get_logger(__name__) |
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
||||
output_hidden_states = ( |
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
||||
) |
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache |
||||
|
||||
if use_cache: |
||||
logger.warning_once( |
||||
"`use_cache=True` is incompatible with pipeline parallelism. Setting `use_cache=False`..." |
||||
) |
||||
use_cache = False |
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
||||
|
||||
# retrieve input_ids and inputs_embeds |
||||
if stage_manager.is_first_stage(): |
||||
if input_ids is not None and inputs_embeds is not None: |
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") |
||||
elif input_ids is not None: |
||||
batch_size, seq_length = input_ids.shape[:2] |
||||
elif inputs_embeds is not None: |
||||
batch_size, seq_length, _ = inputs_embeds.shape[:2] |
||||
else: |
||||
raise ValueError("You have to specify either input_ids or inputs_embeds") |
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device |
||||
if inputs_embeds is None: |
||||
inputs_embeds = self.embed_tokens(input_ids) |
||||
hidden_states = inputs_embeds |
||||
else: |
||||
input_shape = hidden_states.shape[:-1] |
||||
batch_size, seq_length = input_shape |
||||
device = hidden_states.device |
||||
|
||||
past_seen_tokens = 0 |
||||
if use_cache: # kept for BC (cache positions) |
||||
if not isinstance(past_key_values, StaticCache): |
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values) |
||||
past_seen_tokens = past_key_values.get_seq_length() |
||||
if cache_position is None: |
||||
if isinstance(past_key_values, StaticCache): |
||||
raise ValueError("cache_position is a required argument when using StaticCache.") |
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device) |
||||
|
||||
seq_length_with_past = seq_length + past_seen_tokens |
||||
|
||||
# TODO(jianghai): recording of k/v tensors is left as () or None for now; this feature may be added in the future.
||||
if output_attentions: |
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") |
||||
output_attentions = False |
||||
if output_hidden_states: |
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") |
||||
output_hidden_states = False |
||||
if use_cache: |
||||
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") |
||||
use_cache = False |
||||
|
||||
if position_ids is None: |
||||
position_ids = cache_position.unsqueeze(0) |
||||
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings, |
||||
# for the other stages, hidden_states is the output of the previous stage |
||||
if shard_config.enable_flash_attention: |
||||
# in this case, attention_mask is a dict rather than a tensor |
||||
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) |
||||
attention_mask = ColoAttention.prepare_attn_kwargs( |
||||
mask_shape, |
||||
hidden_states.dtype, |
||||
hidden_states.device, |
||||
q_padding_mask=attention_mask, |
||||
is_causal=True, |
||||
) |
||||
else: |
||||
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) |
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache: |
||||
if use_cache: |
||||
logger.warning_once( |
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
||||
) |
||||
use_cache = False |
||||
|
||||
# decoder layers |
||||
all_hidden_states = () if output_hidden_states else None |
||||
all_self_attns = () if output_attentions else None |
||||
next_decoder_cache = None |
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1] |
||||
num_ckpt_layers = 0 |
||||
if self.gradient_checkpointing and self.training: |
||||
num_ckpt_layers = end_idx - start_idx |
||||
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer |
||||
if shard_config.gradient_checkpoint_config is not None: |
||||
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( |
||||
stage=stage_manager.stage, |
||||
num_stages=stage_manager.num_stages, |
||||
num_layers=end_idx - start_idx, |
||||
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), |
||||
num_model_chunks=stage_manager.num_model_chunks, |
||||
) |
||||
assert num_ckpt_layers <= end_idx - start_idx |
||||
|
||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): |
||||
if output_hidden_states: |
||||
all_hidden_states += (hidden_states,) |
||||
|
||||
if idx - start_idx < num_ckpt_layers: |
||||
layer_outputs = self._gradient_checkpointing_func( |
||||
decoder_layer.__call__, |
||||
hidden_states, |
||||
attention_mask, |
||||
position_ids, |
||||
past_key_values, |
||||
output_attentions, |
||||
use_cache, |
||||
cache_position, |
||||
) |
||||
else: |
||||
layer_outputs = decoder_layer( |
||||
hidden_states, |
||||
attention_mask=attention_mask, |
||||
position_ids=position_ids, |
||||
past_key_value=past_key_values, |
||||
output_attentions=output_attentions, |
||||
use_cache=use_cache, |
||||
cache_position=cache_position, |
||||
) |
||||
|
||||
hidden_states = layer_outputs[0] |
||||
|
||||
if use_cache: |
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1] |
||||
if output_attentions: |
||||
all_self_attns += (layer_outputs[1],) |
||||
|
||||
if stage_manager.is_last_stage(): |
||||
hidden_states = self.norm(hidden_states) |
||||
|
||||
# add hidden states from the last decoder layer |
||||
if output_hidden_states: |
||||
all_hidden_states += (hidden_states,) |
||||
next_cache = next_decoder_cache if use_cache else None |
||||
if stage_manager.is_last_stage(): |
||||
if not return_dict: |
||||
return tuple( |
||||
v |
||||
for v in [ |
||||
hidden_states, |
||||
next_cache, |
||||
all_hidden_states, |
||||
all_self_attns, |
||||
] |
||||
if v is not None |
||||
) |
||||
return BaseModelOutputWithPast( |
||||
last_hidden_state=hidden_states, |
||||
past_key_values=next_cache, |
||||
hidden_states=all_hidden_states, |
||||
attentions=all_self_attns, |
||||
) |
||||
# always return a dict for intermediate stages
||||
return {"hidden_states": hidden_states} |
||||
|
||||
@staticmethod |
||||
def command_for_causal_lm_forward( |
||||
self: CohereForCausalLM, |
||||
input_ids: torch.LongTensor = None, |
||||
attention_mask: Optional[torch.Tensor] = None, |
||||
position_ids: Optional[torch.LongTensor] = None, |
||||
past_key_values: Optional[List[torch.FloatTensor]] = None, |
||||
inputs_embeds: Optional[torch.FloatTensor] = None, |
||||
labels: Optional[torch.LongTensor] = None, |
||||
use_cache: Optional[bool] = None, |
||||
output_attentions: Optional[bool] = None, |
||||
output_hidden_states: Optional[bool] = None, |
||||
return_dict: Optional[bool] = None, |
||||
cache_position: Optional[torch.LongTensor] = None, |
||||
stage_manager: Optional[PipelineStageManager] = None, |
||||
hidden_states: Optional[torch.FloatTensor] = None, |
||||
stage_index: Optional[List[int]] = None, |
||||
shard_config: ShardConfig = None, |
||||
): |
||||
r""" |
||||
Args: |
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., |
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored |
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. |
||||
|
||||
Returns: |
||||
|
||||
Example: |
||||
|
||||
```python |
||||
>>> from transformers import AutoTokenizer, CohereForCausalLM |
||||
|
||||
>>> model = CohereForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) |
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) |
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?" |
||||
>>> inputs = tokenizer(prompt, return_tensors="pt") |
||||
|
||||
>>> # Generate |
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30) |
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] |
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." |
||||
```""" |
||||
logger = logging.get_logger(__name__) |
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
||||
output_hidden_states = ( |
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
||||
) |
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
||||
|
||||
# TODO(jianghai): recording of k/v tensors is left as () or None for now; this feature may be added in the future.
||||
if output_attentions: |
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") |
||||
output_attentions = False |
||||
if output_hidden_states: |
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") |
||||
output_hidden_states = False |
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) |
||||
outputs = CommandPipelineForwards.command_model_forward( |
||||
self.model, |
||||
input_ids=input_ids, |
||||
attention_mask=attention_mask, |
||||
position_ids=position_ids, |
||||
past_key_values=past_key_values, |
||||
inputs_embeds=inputs_embeds, |
||||
use_cache=use_cache, |
||||
output_attentions=output_attentions, |
||||
output_hidden_states=output_hidden_states, |
||||
return_dict=return_dict, |
||||
cache_position=cache_position, |
||||
stage_manager=stage_manager, |
||||
hidden_states=hidden_states, |
||||
stage_index=stage_index, |
||||
shard_config=shard_config, |
||||
) |
||||
past_key_values = None |
||||
|
||||
if stage_manager.is_last_stage(): |
||||
hidden_states = outputs[0] |
||||
logits = self.lm_head(hidden_states) |
||||
logits = logits * self.logit_scale |
||||
logits = logits.float() |
||||
loss = None |
||||
if labels is not None: |
||||
# Shift so that tokens < n predict n |
||||
shift_logits = logits[..., :-1, :].contiguous() |
||||
shift_labels = labels[..., 1:].contiguous() |
||||
# Flatten the tokens |
||||
loss_fct = CrossEntropyLoss() |
||||
shift_labels = shift_labels.view(-1) |
||||
# Enable model parallelism |
||||
shift_labels = shift_labels.to(shift_logits.device) |
||||
if shard_config.enable_tensor_parallelism and shard_config.parallel_output: |
||||
new_vocab_size = logits.shape[-1] |
||||
shift_logits = shift_logits.view(-1, new_vocab_size) |
||||
loss = cross_entropy_1d( |
||||
shift_logits, |
||||
shift_labels, |
||||
process_group=shard_config.tensor_parallel_process_group, |
||||
vocab_size=self.lm_head.out_features, |
||||
dtype=self.model.dtype, |
||||
) |
||||
else: |
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size) |
||||
loss = loss_fct(shift_logits, shift_labels) |
||||
|
||||
if not return_dict: |
||||
output = (logits,) + outputs[1:] |
||||
return (loss,) + output if loss is not None else output |
||||
|
||||
return CausalLMOutputWithPast( |
||||
loss=loss, |
||||
logits=logits, |
||||
past_key_values=outputs.past_key_values, |
||||
hidden_states=outputs.hidden_states, |
||||
attentions=outputs.attentions, |
||||
) |
||||
else: |
||||
hidden_states = outputs.get("hidden_states") |
||||
return {"hidden_states": hidden_states} |
||||
|
||||
|
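command_model_forward above runs only the decoder layers in [start_idx, end_idx) on each pipeline stage and hands hidden_states to the next stage. A hedged sketch of an even layer partition is shown below; it is illustrative only, since the actual split comes from PipelineStageManager.distribute_layers()/get_stage_index() and may differ.

def split_layers(num_layers: int, num_stages: int):
    # Evenly distribute decoder layers across pipeline stages; earlier stages take
    # one extra layer when the split is uneven (illustrative partition only).
    base, rem = divmod(num_layers, num_stages)
    bounds, start = [], 0
    for stage in range(num_stages):
        size = base + (1 if stage < rem else 0)
        bounds.append((start, start + size))
        start += size
    return bounds


print(split_layers(40, 4))  # [(0, 10), (10, 20), (20, 30), (30, 40)]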
||||
def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): |
||||
def forward( |
||||
self, |
||||
hidden_states: torch.Tensor, |
||||
attention_mask: Optional[torch.Tensor] = None, |
||||
position_ids: Optional[torch.LongTensor] = None, |
||||
past_key_value: Optional[Cache] = None, |
||||
output_attentions: bool = False, |
||||
use_cache: bool = False, |
||||
cache_position: Optional[torch.LongTensor] = None, |
||||
**kwargs, |
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: |
||||
if sp_mode is not None: |
||||
assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode" |
||||
assert (sp_size is not None) and ( |
||||
sp_group is not None |
||||
), "Must specify sp_size and sp_group for sequence parallel" |
||||
if "padding_mask" in kwargs: |
||||
warnings.warn( |
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" |
||||
) |
||||
|
||||
bsz, q_len, _ = hidden_states.size() |
||||
# sp: modify sp_len when sequence parallel mode is ring |
||||
if sp_mode in ["split_gather", "ring"]: |
||||
q_len *= sp_size |
||||
|
||||
query_states = self.q_proj(hidden_states) |
||||
key_states = self.k_proj(hidden_states) |
||||
value_states = self.v_proj(hidden_states) |
||||
|
||||
# sp: all-to-all communication when introducing sequence parallelism
||||
if sp_mode == "all_to_all": |
||||
query_states = all_to_all_comm(query_states, sp_group) |
||||
key_states = all_to_all_comm(key_states, sp_group) |
||||
value_states = all_to_all_comm(value_states, sp_group) |
||||
bsz, q_len, _ = query_states.size() |
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
||||
|
||||
kv_seq_len = key_states.shape[-2] |
||||
if past_key_value is not None: |
||||
if self.layer_idx is None: |
||||
raise ValueError( |
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " |
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " |
||||
"with a layer index." |
||||
) |
||||
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) |
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids) |
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
||||
|
||||
if past_key_value is not None: |
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} |
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) |
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads |
||||
key_states = repeat_kv(key_states, self.num_key_value_groups) |
||||
value_states = repeat_kv(value_states, self.num_key_value_groups) |
||||
|
||||
if shard_config.enable_flash_attention: |
||||
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." |
||||
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) |
||||
else: |
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) |
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): |
||||
raise ValueError( |
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" |
||||
f" {attn_weights.size()}" |
||||
) |
||||
|
||||
if attention_mask is not None: |
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): |
||||
raise ValueError( |
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" |
||||
) |
||||
attn_weights = attn_weights + attention_mask |
||||
|
||||
# upcast attention to fp32 |
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) |
||||
attn_output = torch.matmul(attn_weights, value_states) |
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): |
||||
raise ValueError( |
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" |
||||
f" {attn_output.size()}" |
||||
) |
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous() |
||||
# sp: all-to-all communication when introducing sequence parallelism
||||
if sp_mode == "all_to_all": |
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) |
||||
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) |
||||
else: |
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) |
||||
|
||||
attn_output = self.o_proj(attn_output) |
||||
|
||||
if not output_attentions: |
||||
attn_weights = None |
||||
return attn_output, attn_weights, past_key_value |
||||
|
||||
return forward |
||||
|
||||
|
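The attention forward above repeats each key/value head so grouped-query attention can reuse the dense matmul path. Below is a small reference equivalent of the transformers `repeat_kv` helper it calls, for readers who want the tensor mechanics spelled out:

import torch


def repeat_kv_reference(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # [batch, num_kv_heads, seq_len, head_dim] -> [batch, num_kv_heads * n_rep, seq_len, head_dim]
    # Each KV head is shared by n_rep query heads, mirroring transformers' repeat_kv.
    if n_rep == 1:
        return x
    b, kv, s, d = x.shape
    return x[:, :, None, :, :].expand(b, kv, n_rep, s, d).reshape(b, kv * n_rep, s, d)


k = torch.randn(2, 4, 16, 64)
assert repeat_kv_reference(k, 3).shape == (2, 12, 16, 64)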
||||
def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): |
||||
logger = logging.get_logger(__name__) |
||||
|
||||
def forward( |
||||
self, |
||||
input_ids: torch.LongTensor = None, |
||||
attention_mask: Optional[torch.Tensor] = None, |
||||
position_ids: Optional[torch.LongTensor] = None, |
||||
past_key_values: Optional[List[torch.FloatTensor]] = None, |
||||
inputs_embeds: Optional[torch.FloatTensor] = None, |
||||
use_cache: Optional[bool] = None, |
||||
output_attentions: Optional[bool] = None, |
||||
output_hidden_states: Optional[bool] = None, |
||||
return_dict: Optional[bool] = None, |
||||
cache_position: Optional[torch.LongTensor] = None, |
||||
) -> Union[Tuple, BaseModelOutputWithPast]: |
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
||||
output_hidden_states = ( |
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
||||
) |
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache |
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
||||
|
||||
# retrieve input_ids and inputs_embeds |
||||
if (input_ids is None) ^ (inputs_embeds is not None): |
||||
raise ValueError( |
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" |
||||
) |
||||
|
||||
if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: |
||||
if use_cache: |
||||
logger.warning_once( |
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
||||
) |
||||
use_cache = False |
||||
|
||||
if inputs_embeds is None: |
||||
inputs_embeds = self.embed_tokens(input_ids) |
||||
|
||||
past_seen_tokens = 0 |
||||
seq_len = inputs_embeds.shape[1] |
||||
if use_cache: # kept for BC (cache positions) |
||||
if not isinstance(past_key_values, StaticCache): |
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values) |
||||
past_seen_tokens = past_key_values.get_seq_length() |
||||
if cache_position is None: |
||||
if isinstance(past_key_values, StaticCache): |
||||
raise ValueError("cache_position is a required argument when using StaticCache.") |
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device) |
||||
|
||||
if position_ids is None: |
||||
position_ids = cache_position.unsqueeze(0) |
||||
|
||||
# in this case, attention_mask is a dict rather than a tensor |
||||
if shard_config.enable_flash_attention: |
||||
mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len) |
||||
attention_mask = ColoAttention.prepare_attn_kwargs( |
||||
mask_shape, |
||||
inputs_embeds.dtype, |
||||
inputs_embeds.device, |
||||
q_padding_mask=attention_mask, |
||||
is_causal=True, |
||||
) |
||||
else: |
||||
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) |
||||
|
||||
if sp_mode in ["ring", "split_gather"]: |
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) |
||||
elif sp_mode == "all_to_all": |
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) |
||||
hidden_states = inputs_embeds |
||||
|
||||
# decoder layers |
||||
all_hidden_states = () if output_hidden_states else None |
||||
all_self_attns = () if output_attentions else None |
||||
next_decoder_cache = None |
||||
|
||||
for decoder_layer in self.layers: |
||||
if output_hidden_states: |
||||
all_hidden_states += (hidden_states,) |
||||
if self.gradient_checkpointing and self.training: |
||||
layer_outputs = self._gradient_checkpointing_func( |
||||
decoder_layer.__call__, |
||||
hidden_states, |
||||
attention_mask, |
||||
position_ids, |
||||
past_key_values, |
||||
output_attentions, |
||||
use_cache, |
||||
cache_position, |
||||
) |
||||
|
||||
else: |
||||
layer_outputs = decoder_layer( |
||||
hidden_states, |
||||
attention_mask=attention_mask, |
||||
position_ids=position_ids, |
||||
past_key_value=past_key_values, |
||||
output_attentions=output_attentions, |
||||
use_cache=use_cache, |
||||
cache_position=cache_position, |
||||
) |
||||
|
||||
hidden_states = layer_outputs[0] |
||||
|
||||
if use_cache: |
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1] |
||||
|
||||
if output_attentions: |
||||
all_self_attns += (layer_outputs[1],) |
||||
|
||||
hidden_states = self.norm(hidden_states) |
||||
|
||||
if sp_mode == "ring" or sp_mode == "split_gather": |
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) |
||||
elif sp_mode == "all_to_all": |
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) |
||||
|
||||
# add hidden states from the last decoder layer |
||||
if output_hidden_states: |
||||
all_hidden_states += (hidden_states,) |
||||
|
||||
next_cache = None |
||||
if use_cache: |
||||
next_cache = ( |
||||
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache |
||||
) |
||||
if not return_dict: |
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) |
||||
|
||||
return BaseModelOutputWithPast( |
||||
last_hidden_state=hidden_states, |
||||
past_key_values=next_cache, |
||||
hidden_states=all_hidden_states, |
||||
attentions=all_self_attns, |
||||
) |
||||
|
||||
return forward |
||||
|
||||
|
||||
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): |
||||
from transformers import CohereForCausalLM |
||||
|
||||
def forward( |
||||
self: CohereForCausalLM, |
||||
input_ids: torch.LongTensor = None, |
||||
attention_mask: Optional[torch.Tensor] = None, |
||||
position_ids: Optional[torch.LongTensor] = None, |
||||
past_key_values: Optional[List[torch.FloatTensor]] = None, |
||||
inputs_embeds: Optional[torch.FloatTensor] = None, |
||||
labels: Optional[torch.LongTensor] = None, |
||||
use_cache: Optional[bool] = None, |
||||
output_attentions: Optional[bool] = None, |
||||
output_hidden_states: Optional[bool] = None, |
||||
return_dict: Optional[bool] = None, |
||||
cache_position: Optional[torch.LongTensor] = None, |
||||
) -> Union[Tuple, CausalLMOutputWithPast]: |
||||
r""" |
||||
Args: |
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., |
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored |
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. |
||||
|
||||
Returns: |
||||
|
||||
Example: |
||||
|
||||
```python |
||||
>>> from transformers import AutoTokenizer, CohereForCausalLM |
||||
|
||||
>>> model = CohereForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) |
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) |
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?" |
||||
>>> inputs = tokenizer(prompt, return_tensors="pt") |
||||
|
||||
>>> # Generate |
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30) |
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] |
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." |
||||
```""" |
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
||||
output_hidden_states = ( |
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
||||
) |
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) |
||||
outputs = self.model( |
||||
input_ids=input_ids, |
||||
attention_mask=attention_mask, |
||||
position_ids=position_ids, |
||||
past_key_values=past_key_values, |
||||
inputs_embeds=inputs_embeds, |
||||
use_cache=use_cache, |
||||
output_attentions=output_attentions, |
||||
output_hidden_states=output_hidden_states, |
||||
return_dict=return_dict, |
||||
cache_position=cache_position, |
||||
) |
||||
|
||||
hidden_states = outputs[0] |
||||
|
||||
logits = self.lm_head(hidden_states) |
||||
logits = logits * self.logit_scale |
||||
logits = logits.float() |
||||
|
||||
loss = None |
||||
if labels is not None: |
||||
# Shift so that tokens < n predict n |
||||
shift_logits = logits[..., :-1, :].contiguous() |
||||
shift_labels = labels[..., 1:].contiguous() |
||||
shift_labels = shift_labels.view(-1) |
||||
# Enable model parallelism |
||||
shift_labels = shift_labels.to(shift_logits.device) |
||||
new_vocab_size = logits.shape[-1] |
||||
shift_logits = shift_logits.view(-1, new_vocab_size) |
||||
loss = cross_entropy_1d( |
||||
shift_logits, |
||||
shift_labels, |
||||
process_group=shard_config.tensor_parallel_process_group, |
||||
vocab_size=self.lm_head.out_features, |
||||
dtype=self.model.dtype, |
||||
) |
||||
|
||||
if not return_dict: |
||||
output = (logits,) + outputs[1:] |
||||
return (loss,) + output if loss is not None else output |
||||
|
||||
return CausalLMOutputWithPast( |
||||
loss=loss, |
||||
logits=logits, |
||||
past_key_values=outputs.past_key_values, |
||||
hidden_states=outputs.hidden_states, |
||||
attentions=outputs.attentions, |
||||
) |
||||
|
||||
return forward |
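Both causal-LM forwards shift logits and labels by one position before the loss; `cross_entropy_1d` additionally splits the vocabulary across tensor-parallel ranks. For reference, a single-device sketch of the shifted loss without the vocab sharding (standard ignore_index convention assumed):

import torch
import torch.nn.functional as F


def shifted_causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: [batch, seq_len, vocab]; labels: [batch, seq_len].
    # Tokens < n predict token n, so drop the last logit and the first label.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,  # masked positions, following the HF labels convention
    )


loss = shifted_causal_lm_loss(torch.randn(2, 8, 32), torch.randint(0, 32, (2, 8)))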
@@ -0,0 +1,369 @@
import warnings |
||||
from functools import partial |
||||
from typing import Callable, Dict, List, Union |
||||
|
||||
import torch.nn as nn |
||||
from torch import Tensor |
||||
from torch.nn import Module |
||||
|
||||
from colossalai.shardformer.layer import ( |
||||
FusedLayerNorm, |
||||
LayerNorm, |
||||
Linear1D_Col, |
||||
Linear1D_Row, |
||||
PaddingEmbedding, |
||||
PaddingLMHead, |
||||
VocabParallelEmbedding1D, |
||||
VocabParallelLMHead1D, |
||||
) |
||||
|
||||
from ..modeling.command import ( |
||||
CommandPipelineForwards, |
||||
get_command_flash_attention_forward, |
||||
get_command_flash_attention_model_forward, |
||||
get_lm_forward_with_dist_cross_entropy, |
||||
) |
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription |
||||
|
||||
__all__ = ["CommandPolicy", "CommandForCausalLMPolicy"] |
||||
|
||||
|
||||
class CommandPolicy(Policy): |
||||
def config_sanity_check(self): |
||||
pass |
||||
|
||||
def preprocess(self): |
||||
self.tie_weight = self.tie_weight_check() |
||||
self.origin_attn_implement = self.model.config._attn_implementation |
||||
return self.model |
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: |
||||
from transformers.models.cohere.modeling_cohere import ( |
||||
CohereAttention, |
||||
CohereDecoderLayer, |
||||
CohereFlashAttention2, |
||||
CohereModel, |
||||
CohereSdpaAttention, |
||||
) |
||||
|
||||
ATTN_IMPLEMENTATION = { |
||||
"eager": CohereAttention, |
||||
"flash_attention_2": CohereFlashAttention2, |
||||
"sdpa": CohereSdpaAttention, |
||||
} |
||||
policy = {} |
||||
|
||||
attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] |
||||
embedding_cls = None |
||||
if self.shard_config.enable_tensor_parallelism: |
||||
embedding_cls = VocabParallelEmbedding1D |
||||
else: |
||||
if self.tie_weight: |
||||
embedding_cls = PaddingEmbedding |
||||
|
||||
if self.shard_config.enable_fused_normalization: |
||||
norm_cls = FusedLayerNorm |
||||
else: |
||||
norm_cls = LayerNorm |
||||
|
||||
if self.pipeline_stage_manager is not None: |
||||
self.shard_config.enable_sequence_parallelism = False |
||||
self.shard_config.enable_sequence_overlap = False |
||||
self.shard_config.sequence_parallelism_mode = None |
||||
warnings.warn( |
||||
f"For Command, sequence parallelism is currently not compatible with pipeline parallelism, set to be False" |
||||
) |
||||
sp_mode = self.shard_config.sequence_parallelism_mode or None |
||||
sp_size = self.shard_config.sequence_parallel_size or None |
||||
sp_group = self.shard_config.sequence_parallel_process_group or None |
||||
sp_partial_derived = sp_mode in ["split_gather", "ring"] |
||||
|
||||
if sp_mode == "all_to_all": |
||||
decoder_attribute_replacement = { |
||||
"num_heads": self.model.config.num_attention_heads // sp_size, |
||||
} |
||||
if getattr(self.model.config, "num_key_value_heads", False): |
||||
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size |
||||
|
||||
policy[attn_cls] = ModulePolicyDescription( |
||||
attribute_replacement=decoder_attribute_replacement, |
||||
) |
||||
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism: |
||||
self.append_or_create_method_replacement( |
||||
description={ |
||||
"forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), |
||||
}, |
||||
policy=policy, |
||||
target_key=attn_cls, |
||||
) |
||||
if self.pipeline_stage_manager is None: |
||||
self.append_or_create_method_replacement( |
||||
description={ |
||||
"forward": get_command_flash_attention_model_forward( |
||||
self.shard_config, |
||||
sp_mode=sp_mode, |
||||
sp_size=sp_size, |
||||
sp_group=sp_group, |
||||
), |
||||
}, |
||||
policy=policy, |
||||
target_key=CohereModel, |
||||
) |
||||
|
||||
if self.shard_config.enable_tensor_parallelism: |
||||
assert ( |
||||
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 |
||||
), f"The number of attention heads must be divisible by tensor parallel size." |
||||
if hasattr(self.model.config, "num_key_value_heads"): |
||||
assert ( |
||||
self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size |
||||
and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0 |
||||
), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size." |
||||
decoder_attribute_replacement = { |
||||
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, |
||||
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, |
||||
} |
||||
if getattr(self.model.config, "num_key_value_heads", False): |
||||
decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( |
||||
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size |
||||
) |
||||
|
||||
policy[CohereDecoderLayer] = ModulePolicyDescription( |
||||
attribute_replacement=decoder_attribute_replacement, |
||||
sub_module_replacement=[ |
||||
SubModuleReplacementDescription( |
||||
suffix="self_attn.q_proj", |
||||
target_module=Linear1D_Col, |
||||
kwargs=dict(seq_parallel_mode=sp_mode), |
||||
), |
||||
SubModuleReplacementDescription( |
||||
suffix="self_attn.k_proj", |
||||
target_module=Linear1D_Col, |
||||
kwargs=dict(seq_parallel_mode=sp_mode), |
||||
), |
||||
SubModuleReplacementDescription( |
||||
suffix="self_attn.v_proj", |
||||
target_module=Linear1D_Col, |
||||
kwargs=dict(seq_parallel_mode=sp_mode), |
||||
), |
||||
SubModuleReplacementDescription( |
||||
suffix="self_attn.o_proj", |
||||
target_module=Linear1D_Row, |
||||
kwargs=dict(seq_parallel_mode=sp_mode), |
||||
), |
||||
SubModuleReplacementDescription( |
||||
suffix="mlp.gate_proj", |
||||
target_module=Linear1D_Col, |
||||
kwargs=dict(seq_parallel_mode=sp_mode), |
||||
), |
||||
SubModuleReplacementDescription( |
||||
suffix="mlp.up_proj", |
||||
target_module=Linear1D_Col, |
||||
kwargs=dict(seq_parallel_mode=sp_mode), |
||||
), |
||||
SubModuleReplacementDescription( |
||||
suffix="mlp.down_proj", |
||||
target_module=Linear1D_Row, |
||||
kwargs=dict(seq_parallel_mode=sp_mode), |
||||
), |
||||
], |
||||
) |
||||
|
||||
if embedding_cls is not None: |
||||
self.append_or_create_submodule_replacement( |
||||
description=SubModuleReplacementDescription( |
||||
suffix="embed_tokens", |
||||
target_module=embedding_cls, |
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, |
||||
), |
||||
policy=policy, |
||||
target_key=CohereModel, |
||||
) |
||||
|
||||
# optimization configuration |
||||
self.append_or_create_submodule_replacement( |
||||
description=[ |
||||
SubModuleReplacementDescription( |
||||
suffix="input_layernorm", |
||||
target_module=norm_cls, |
||||
kwargs={"sp_partial_derived": sp_partial_derived}, |
||||
), |
||||
], |
||||
policy=policy, |
||||
target_key=CohereDecoderLayer, |
||||
) |
||||
|
||||
self.append_or_create_submodule_replacement( |
||||
description=SubModuleReplacementDescription( |
||||
suffix="norm", |
||||
target_module=norm_cls, |
||||
kwargs={"sp_partial_derived": sp_partial_derived}, |
||||
), |
||||
policy=policy, |
||||
target_key=CohereModel, |
||||
) |
||||
|
||||
return policy |
||||
|
||||
def postprocess(self): |
||||
return self.model |
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: |
||||
"""If under pipeline parallel setting, replacing the original forward method of huggingface |
||||
to customized forward method, and add this changing to policy.""" |
||||
if self.pipeline_stage_manager is None: |
||||
return |
||||
|
||||
stage_manager = self.pipeline_stage_manager |
||||
if self.model.__class__.__name__ == "CohereModel": |
||||
module = self.model |
||||
else: |
||||
module = self.model.model |
||||
|
||||
if stage_manager.is_interleave: |
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers)) |
||||
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) |
||||
method_replacement = { |
||||
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) |
||||
} |
||||
|
||||
else: |
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers)) |
||||
stage_index = stage_manager.get_stage_index(layers_per_stage) |
||||
method_replacement = { |
||||
"forward": partial( |
||||
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config |
||||
) |
||||
} |
||||
|
||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) |
||||
|
||||
def get_held_layers(self) -> List[Module]: |
||||
"""Get pipeline layers for current stage.""" |
||||
assert self.pipeline_stage_manager is not None |
||||
|
||||
if self.model.__class__.__name__ == "CohereModel": |
||||
module = self.model |
||||
else: |
||||
module = self.model.model |
||||
stage_manager = self.pipeline_stage_manager |
||||
|
||||
held_layers = [] |
||||
if stage_manager.is_interleave: |
||||
assert stage_manager.num_model_chunks is not None |
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers)) |
||||
stage_indices = stage_manager.get_stage_index(layers_per_stage) |
||||
if stage_manager.is_first_stage(ignore_chunk=True): |
||||
held_layers.append(module.embed_tokens) |
||||
for start_idx, end_idx in stage_indices: |
||||
held_layers.extend(module.layers[start_idx:end_idx]) |
||||
if stage_manager.is_last_stage(ignore_chunk=True): |
||||
held_layers.append(module.norm) |
||||
|
||||
else: |
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers)) |
||||
if stage_manager.is_first_stage(): |
||||
held_layers.append(module.embed_tokens) |
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) |
||||
held_layers.extend(module.layers[start_idx:end_idx]) |
||||
if stage_manager.is_last_stage(): |
||||
held_layers.append(module.norm) |
||||
|
||||
return held_layers |
||||
|
||||
|
||||
class CommandModelPolicy(CommandPolicy): |
||||
def module_policy(self): |
||||
policy = super().module_policy() |
||||
from transformers.models.cohere.modeling_cohere import CohereModel |
||||
|
||||
if self.pipeline_stage_manager: |
||||
# set None as default |
||||
self.set_pipeline_forward( |
||||
model_cls=CohereModel, new_forward=CommandPipelineForwards.command_model_forward, policy=policy |
||||
) |
||||
return policy |
||||
|
||||
def get_held_layers(self) -> List[Module]: |
||||
"""Get pipeline layers for current stage.""" |
||||
held_layers = super().get_held_layers() |
||||
return held_layers |
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]: |
||||
"""No shared params in command model""" |
||||
return [] |
||||
|
||||
|
||||
class CommandForCausalLMPolicy(CommandPolicy): |
||||
def module_policy(self): |
||||
from transformers import CohereForCausalLM |
||||
|
||||
policy = super().module_policy() |
||||
|
||||
if self.shard_config.enable_tensor_parallelism: |
||||
# add a new item for causal lm
||||
new_item = { |
||||
CohereForCausalLM: ModulePolicyDescription( |
||||
sub_module_replacement=[ |
||||
SubModuleReplacementDescription( |
||||
suffix="lm_head", |
||||
target_module=VocabParallelLMHead1D, |
||||
kwargs={ |
||||
"gather_output": not self.shard_config.parallel_output, |
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, |
||||
}, |
||||
) |
||||
], |
||||
) |
||||
} |
||||
if self.shard_config.parallel_output: |
||||
new_item[CohereForCausalLM].method_replacement = { |
||||
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) |
||||
} |
||||
else: |
||||
new_item = { |
||||
CohereForCausalLM: ModulePolicyDescription( |
||||
sub_module_replacement=[ |
||||
SubModuleReplacementDescription( |
||||
suffix="lm_head", |
||||
target_module=PaddingLMHead, |
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, |
||||
) |
||||
], |
||||
) |
||||
} |
||||
policy.update(new_item) |
||||
|
||||
if self.pipeline_stage_manager: |
||||
# set None as default |
||||
self.set_pipeline_forward( |
||||
model_cls=CohereForCausalLM, |
||||
new_forward=CommandPipelineForwards.command_for_causal_lm_forward, |
||||
policy=policy, |
||||
) |
||||
|
||||
return policy |
||||
|
||||
def get_held_layers(self) -> List[Module]: |
||||
"""Get pipeline layers for current stage.""" |
||||
stage_manager = self.pipeline_stage_manager |
||||
held_layers = super().get_held_layers() |
||||
if stage_manager.is_last_stage(ignore_chunk=True): |
||||
held_layers.append(self.model.lm_head) |
||||
return held_layers |
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]: |
||||
command_model = self.model.model |
||||
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: |
||||
if ( |
||||
id(command_model.embed_tokens.weight) == id(self.model.lm_head.weight) |
||||
and self.pipeline_stage_manager.num_stages > 1 |
||||
): |
||||
# tie weights |
||||
return [ |
||||
{ |
||||
0: command_model.embed_tokens.weight, |
||||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, |
||||
} |
||||
] |
||||
return [] |
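get_shared_params detects tied input/output embeddings by comparing tensor identity and only then reports the pair that must stay synchronized between the first and last pipeline stages. A toy sketch of that identity test follows; TinyTiedLM is a hypothetical module used purely for illustration, not part of ColossalAI:

import torch.nn as nn


class TinyTiedLM(nn.Module):
    # Hypothetical toy module, only to illustrate the identity test.
    def __init__(self, vocab: int = 100, dim: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)
        self.lm_head.weight = self.embed_tokens.weight  # weight tying


m = TinyTiedLM()
# Same check the policy performs before reporting a shared-parameter mapping
# between the first (embedding) and last (lm_head) pipeline stages.
assert id(m.embed_tokens.weight) == id(m.lm_head.weight)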
@@ -0,0 +1,79 @@
import torch |
||||
import transformers |
||||
|
||||
from ..registry import ModelAttribute, model_zoo |
||||
|
||||
try: |
||||
from transformers import CohereConfig |
||||
|
||||
HAS_COMMAND = True |
||||
except ImportError: |
||||
HAS_COMMAND = False |
||||
|
||||
if HAS_COMMAND: |
||||
# =============================== |
||||
# Register Command-R |
||||
# =============================== |
||||
|
||||
def data_gen(): |
||||
input_ids = torch.Tensor( |
||||
[ |
||||
[1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], |
||||
[1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], |
||||
] |
||||
).long() |
||||
|
||||
attention_mask = torch.Tensor( |
||||
[ |
||||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], |
||||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], |
||||
] |
||||
).long() |
||||
|
||||
return dict(input_ids=input_ids, attention_mask=attention_mask) |
||||
|
||||
# labels are needed for causal lm
||||
def data_gen_for_casual_lm(): |
||||
data = data_gen() |
||||
labels = data["input_ids"].clone() |
||||
data["labels"] = labels |
||||
return data |
||||
|
||||
# transform the output to a dict |
||||
output_transform_fn = lambda x: x |
||||
|
||||
# function to get the loss |
||||
loss_fn = lambda output: output["last_hidden_state"].mean() |
||||
loss_fn_for_casual_lm = lambda output: output["loss"] |
||||
loss_fn_for_seq_classification = lambda output: output["logits"].mean() |
||||
|
||||
config = CohereConfig( |
||||
num_hidden_layers=8, |
||||
hidden_size=32, |
||||
intermediate_size=64, |
||||
num_attention_heads=4, |
||||
max_position_embeddings=128, |
||||
) |
||||
|
||||
if hasattr(config, "pad_token_id"): |
||||
config.pad_token_id = config.eos_token_id |
||||
|
||||
# register the following models |
||||
# transformers.CohereModel, |
||||
# transformers.CohereForCausalLM, |
||||
model_zoo.register( |
||||
name="transformers_command", |
||||
model_fn=lambda: transformers.CohereModel(config), |
||||
data_gen_fn=data_gen, |
||||
output_transform_fn=output_transform_fn, |
||||
loss_fn=loss_fn, |
||||
model_attribute=ModelAttribute(has_control_flow=True), |
||||
) |
||||
model_zoo.register( |
||||
name="transformers_command_for_casual_lm", |
||||
model_fn=lambda: transformers.CohereForCausalLM(config), |
||||
data_gen_fn=data_gen_for_casual_lm, |
||||
output_transform_fn=output_transform_fn, |
||||
loss_fn=loss_fn_for_casual_lm, |
||||
model_attribute=ModelAttribute(has_control_flow=True), |
||||
) |
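The registry deliberately uses a very small CohereConfig so shard tests stay cheap. A hedged sketch of exercising the causal-LM variant directly (assumes a transformers release that ships the Cohere architecture; the two-layer config here is an even smaller illustrative variant, not the registered one):

import torch
from transformers import CohereConfig, CohereForCausalLM

tiny_config = CohereConfig(
    num_hidden_layers=2,  # even smaller than the registered config, purely for illustration
    hidden_size=32,
    intermediate_size=64,
    num_attention_heads=4,
    max_position_embeddings=128,
)
model = CohereForCausalLM(tiny_config)
input_ids = torch.randint(0, tiny_config.vocab_size, (2, 16))
out = model(input_ids=input_ids, labels=input_ids.clone())
print(out.loss)  # the quantity loss_fn_for_casual_lm extracts above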
@@ -0,0 +1,161 @@
import os |
||||
import random |
||||
|
||||
import numpy as np |
||||
import pytest |
||||
import torch |
||||
import torch.distributed as dist |
||||
from torch.multiprocessing import Manager |
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, LlamaForCausalLM, LlamaTokenizer |
||||
|
||||
import colossalai |
||||
import colossalai.inference.modeling.policy as policy |
||||
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig |
||||
from colossalai.inference.core.engine import InferenceEngine |
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn |
||||
|
||||
# NOTE: To test a model with the inference engine, you need to provide the path to your |
||||
# local pretrained model weights in the MODEL_MAP dictionary |
||||
MODEL_MAP = { |
||||
"baichuan": { |
||||
"model": AutoModelForCausalLM, |
||||
"tokenizer": AutoTokenizer, |
||||
"policy": policy.NoPaddingBaichuanModelInferPolicy, |
||||
"model_name_or_path": "baichuan-inc/Baichuan2-13B-Base", # provide the path to local model weights |
||||
}, |
||||
"llama": { |
||||
"model": LlamaForCausalLM, |
||||
"tokenizer": LlamaTokenizer, |
||||
"policy": policy.NoPaddingLlamaModelInferPolicy, |
||||
"model_name_or_path": "meta-llama/Llama-2-70b-hf", |
||||
}, |
||||
} |
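
# Hedged example (editor's addition): as the NOTE above says, the Hugging Face IDs in MODEL_MAP are
# meant to be replaced with local checkpoint directories before running; the path below is only a
# placeholder, not a real location.
# MODEL_MAP["llama"]["model_name_or_path"] = "/path/to/local/llama-2-checkpoint"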

MODELS_TO_TEST = ["llama", "baichuan"]  # Specify the models to test


@parameterize("model", MODELS_TO_TEST)
@parameterize("prompt_template", [None, "model_specific"])
@parameterize("do_sample", [False])
@parameterize("use_cuda_kernel", [True])
@pytest.mark.largedist
@rerun_if_address_is_in_use()
def test_model(model, prompt_template, do_sample, use_cuda_kernel):
    model_path = MODEL_MAP[model]["model_name_or_path"]
    if not os.path.exists(model_path):
        pytest.skip(
            f"There is no local model path provided for {model}, please replace it with a valid one."
        )

    if prompt_template == "model_specific":
        prompt_template = model

    model_config = MODEL_MAP[model]

    kwargs1 = {
        "model": model,
        "use_engine": True,
        "prompt_template": prompt_template,
        "do_sample": do_sample,
        "policy": model_config["policy"](),
        "use_cuda_kernel": use_cuda_kernel,
    }

    kwargs2 = {
        "model": model,
        "use_engine": False,
        "prompt_template": prompt_template,
        "do_sample": do_sample,
        "policy": None,
        "use_cuda_kernel": use_cuda_kernel,
    }

    colossal_tp_1_output = run_engine(1, **kwargs1)
    colossal_tp_2_output = run_engine(2, **kwargs1)
    transformer_tp_1_output = run_engine(1, **kwargs2)

    for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
        assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
        assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"


def run_engine(world_size, **kwargs):
    manager = Manager()
    result_list = manager.list([-1] * world_size)  # Create a shared list
    spawn(run_dist, world_size, func_to_run=_run_engine, ret=result_list, **kwargs)
    return result_list[0]
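

# Hedged usage sketch (editor's addition, mirroring how test_model drives run_engine): compare an
# engine-backed run on two tensor-parallel ranks against a plain transformers run on one rank.
# All keyword values below are illustrative; the helper is never called by the test suite.
def _example_compare_tp_outputs():
    engine_kwargs = {
        "model": "llama",
        "use_engine": True,
        "prompt_template": None,
        "do_sample": False,
        "policy": MODEL_MAP["llama"]["policy"](),
        "use_cuda_kernel": True,
    }
    hf_kwargs = {**engine_kwargs, "use_engine": False, "policy": None}
    tp2_output = run_engine(2, **engine_kwargs)
    hf_output = run_engine(1, **hf_kwargs)
    assert tp2_output == hf_output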


def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

    if ret:
        ret[rank] = func_to_run(**kwargs)
    else:
        func_to_run(**kwargs)


def _run_engine(model, use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None, policy=None):
    setup_seed(20)
    model_config = MODEL_MAP[model]
    model_name_or_path = model_config["model_name_or_path"]
    tokenizer = model_config["tokenizer"].from_pretrained(model_name_or_path, use_fast=False, trust_remote_code=True)
    model = model_config["model"].from_pretrained(model_name_or_path, trust_remote_code=True).half().cuda()
    model = model.eval()

    inputs = [
        "Introduce some landmarks in Paris:",
    ]

    output_len = 38

    if do_sample:
        top_p = 0.5
        top_k = 50
    else:
        top_p = None
        top_k = None

    if use_engine:
        inference_config = InferenceConfig(
            max_output_len=output_len,
            prompt_template=prompt_template,
            use_cuda_kernel=use_cuda_kernel,
            tp_size=dist.get_world_size(),
        )
        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
        assert inference_engine.generation_config.max_new_tokens == output_len
        inference_engine.add_request(prompts=inputs)
        assert inference_engine.request_handler._has_waiting()
        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=output_len)
        outputs = inference_engine.generate(generation_config=generation_config)
    else:
        if prompt_template:
            # apply prompt template
            inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
        inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
        inputs = inputs.cuda()
        generation_config = GenerationConfig(
            do_sample=do_sample,
            top_p=top_p,
            top_k=top_k,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=output_len,
        )
        outputs = model.generate(inputs, generation_config=generation_config)
        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


if __name__ == "__main__":
    test_model()
@ -0,0 +1,322 @@
import os

import pytest
import torch
import torch.distributed as dist
from torch.testing import assert_close

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import PipelineGradientCheckpointConfig
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
    build_model_from_hybrid_plugin,
    check_all_grad_tensors,
    check_loss,
    check_output_hidden_state,
    check_weight,
    get_grad_tensors_for_check,
    run_forward_backward_with_hybrid_plugin,
    unwrap_model,
)

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
    enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
        model_fn, loss_fn, test_config
    )
    if enable_gradient_checkpointing:
        # org_model.gradient_checkpointing_enable()
        sharded_model.unwrap().gradient_checkpointing_enable()

    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
    )

    stage_manager = booster.plugin.stage_manager
    tp_group = booster.plugin.tp_group

    # unwrap model
    command_model = unwrap_model(org_model, "CohereModel", "model")
    shard_command_model = unwrap_model(sharded_model, "CohereModel", "model")

    row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
    col_layer_for_check = ["layers[0].self_attn.o_proj"]
    # Here we check the grad of layernorm because an all-reduce operation should be performed during sequence parallelism
    norm_layer_for_check = ["layers[0].input_layernorm", "layers[1].input_layernorm"]

    # During pipeline parallelism, we cannot get the grad of the norm layer on the first stage, so we only check it when pp is not enabled
    if stage_manager is None:
        norm_layer_for_check.append("norm")

    # Check the grad when using ZeRO-1 and ZeRO-2
    if (
        booster.plugin.zero_stage in [1, 2]
        and booster.plugin.shard_config.enable_sequence_parallelism
        and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
    ):
        for p1, p2 in zip(command_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
            working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
            grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
            grad_index = (
                0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank
            )
            grad = grads[grad_index]
            sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
            assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)

    # Save gradient tensors for comparison between the original model and the sharded model before the optimizer step.
    grads_to_check = {}
    if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-6, 1e-4
        else:
            atol, rtol = 5e-3, 5e-3
        row_layer_grads = get_grad_tensors_for_check(
            command_model,
            shard_command_model,
            row_layer_for_check,
            tp_group,
            atol=atol,
            rtol=rtol,
            dim=0,
            verbose=False,
        )
        col_layer_grads = get_grad_tensors_for_check(
            command_model,
            shard_command_model,
            col_layer_for_check,
            tp_group,
            atol=atol,
            rtol=rtol,
            dim=1,
            verbose=False,
        )
        norm_layer_grads = get_grad_tensors_for_check(
            command_model,
            shard_command_model,
            norm_layer_for_check,
            tp_group,
            atol=atol,
            rtol=rtol,
            dim=1,
            verbose=False,
        )
        grads_to_check.update(col_layer_grads)
        grads_to_check.update(row_layer_grads)
        grads_to_check.update(norm_layer_grads)

    # optimizer executes step
    org_optimizer.step()
    sharded_optimizer.step()

    # check last hidden state & loss
    if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-5, 1e-3
        else:
            atol, rtol = 5e-3, 5e-3

        if org_model.__class__.__name__ == "CohereModel":
            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

    # check weights
    if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
        if test_config["precision"] == "fp32":
            atol, rtol = 5e-4, 1e-3
        else:
            atol, rtol = 5e-3, 5e-3
        check_weight(
            command_model,
            shard_command_model,
            col_layer_for_check,
            tp_group,
            atol=atol,
            rtol=rtol,
            dim=1,
            verbose=False,
        )

    # check grads
    check_all_grad_tensors(grads_to_check)

    torch.cuda.empty_cache()
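

# Hedged illustration (editor's addition): the `dim` arguments in the grad checks above follow how
# tensor parallelism shards each layer -- the row-parallel layers listed in `row_layer_for_check`
# are gathered along dim 0 and the column-parallel layers along dim 1. The toy tensor below is
# hypothetical and the helper is never called by the test.
def _example_tp_shard_dims():
    w = torch.arange(16.0).reshape(4, 4)
    row_shards = torch.chunk(w, 2, dim=0)  # per-rank slice of a row-parallel weight (2 ranks assumed)
    col_shards = torch.chunk(w, 2, dim=1)  # per-rank slice of a column-parallel weight
    assert row_shards[0].shape == (2, 4) and col_shards[0].shape == (4, 2)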


@parameterize(
    "test_config",
    [
        {
            "tp_size": 2,
            "pp_size": 1,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "ring",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 2,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 4,
            "pp_size": 1,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "split_gather",
            "enable_flash_attention": False,
            "use_lazy_init": True,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 1,
            "pp_size": 1,
            "sp_size": 2,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "use_lazy_init": True,
            "zero_stage": 2,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 2,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "precision": "fp16",
            "initial_scale": 1,
            "enable_gradient_checkpointing": True,
            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
        },
        {
            "tp_size": 1,
            "pp_size": 2,
            "num_microbatches": 4,
            "use_lazy_init": False,
            "precision": "fp32",
            "enable_gradient_checkpointing": True,
            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
        },
        {
            "tp_size": 2,
            "pp_size": 1,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "zero_stage": 2,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 1,
            "pp_size": 2,
            "num_microbatches": 2,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
    ],
)
def run_command_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()


@parameterize(
    "test_config",
    [
        {
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "use_lazy_init": False,
            "precision": "fp32",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "use_lazy_init": False,
            "precision": "fp16",
            "zero_stage": 1,
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "pp_style": "interleaved",
            "num_model_chunks": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "precision": "fp16",
            "zero_stage": 1,
            "initial_scale": 1,
            "enable_gradient_checkpointing": True,
            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
                num_ckpt_layers_per_stage=[0, 1, 2, 2],
            ),
        },
    ],
)
def run_command_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()


def check_command(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_command_test()


def check_command_3d(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_command_3d_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_command():
    spawn(check_command, 4)


@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_command_3d():
    spawn(check_command_3d, 8)


if __name__ == "__main__":
    test_command()
    test_command_3d()