add support for bloom (#5008)

pull/5023/head^2
Bin Jia 2023-11-06 09:35:33 +08:00 committed by FoolPlayer
parent f747d13040
commit 48d0a58d10
8 changed files with 721 additions and 25 deletions

View File

@ -1,4 +1,4 @@
from .hybridengine import CaiInferEngine
from .hybridengine.polices import LlamaModelInferPolicy
from .hybridengine.polices import BloomModelInferPolicy, LlamaModelInferPolicy
__all__ = ["CaiInferEngine", "LlamaModelInferPolicy"]
__all__ = ["CaiInferEngine", "LlamaModelInferPolicy", "BloomModelInferPolicy"]

View File

@ -16,6 +16,7 @@ PP_AXIS, TP_AXIS = 0, 1
_supported_models = [
"LlamaForCausalLM",
"BloomForCausalLM",
]
@ -155,12 +156,20 @@ class CaiInferEngine:
def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
max_total_token_num = max_batch_size * (max_input_len + max_output_len)
head_dim = model.config.hidden_size // model.config.num_attention_heads
head_num = model.config.num_attention_heads
num_hidden_layers = (
model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
)
layer_num = num_hidden_layers // self.pp_size
if model.config.model_type == "llama":
head_dim = model.config.hidden_size // model.config.num_attention_heads
head_num = model.config.num_attention_heads // self.tp_size
num_hidden_layers = (
model.config.num_hidden_layers
if hasattr(model.config, "num_hidden_layers")
else model.config.num_layers
)
layer_num = num_hidden_layers // self.pp_size
elif model.config.model_type == "bloom":
head_dim = model.config.hidden_size // model.config.n_head
head_num = model.config.n_head // self.tp_size
num_hidden_layers = model.config.n_layer
layer_num = num_hidden_layers // self.pp_size
cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
return cache_manager

View File

@ -0,0 +1,452 @@
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch.nn import functional as F
from transformers.models.bloom.modeling_bloom import (
BaseModelOutputWithPastAndCrossAttentions,
BloomAttention,
BloomBlock,
BloomForCausalLM,
BloomModel,
)
from transformers.utils import logging
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd
from colossalai.pipeline.stage_manager import PipelineStageManager
try:
from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_bloom_context_attention_fwd,
)
HAS_LIGHTLLM_KERNEL = True
except:
HAS_LIGHTLLM_KERNEL = False
def generate_alibi(n_head, dtype=torch.float16):
"""
This method is adapted from `_generate_alibi` function
in `lightllm/models/bloom/layer_weights/transformer_layer_weight.py`
of the ModelTC/lightllm GitHub repository.
This method is originally the `build_alibi_tensor` function
in `transformers/models/bloom/modeling_bloom.py`
of the huggingface/transformers GitHub repository.
"""
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
return [start * start**i for i in range(n)]
def get_slopes(n):
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2)
slopes_double = get_slopes(2 * closest_power_of_2)
slopes_combined = slopes_power_of_2 + slopes_double[0::2][: n - closest_power_of_2]
return slopes_combined
slopes = get_slopes(n_head)
return torch.tensor(slopes, dtype=dtype)
class BloomInferenceForwards:
"""
This class serves a micro library for bloom inference forwards.
We intend to replace the forward methods for BloomForCausalLM, BloomModel, BloomBlock, and BloomAttention,
as well as prepare_inputs_for_generation method for BloomForCausalLM.
For future improvement, we might want to skip replacing methods for BloomForCausalLM,
and call BloomModel.forward iteratively in TpInferEngine
"""
@staticmethod
def bloom_for_causal_lm_forward(
self: BloomForCausalLM,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = False,
infer_state: BatchInferState = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
tp_group: Optional[dist.ProcessGroup] = None,
**deprecated_arguments,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
logger = logging.get_logger(__name__)
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
" passing `position_ids`.",
FutureWarning,
)
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
# If is first stage and hidden_states is not None, go throught lm_head first
if stage_manager.is_first_stage() and hidden_states is not None:
lm_logits = self.lm_head(hidden_states)
return {"logits": lm_logits}
outputs = BloomInferenceForwards.bloom_model_forward(
self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
infer_state=infer_state,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
tp_group=tp_group,
)
return outputs
@staticmethod
def bloom_model_forward(
self: BloomModel,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = None,
infer_state: BatchInferState = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
tp_group: Optional[dist.ProcessGroup] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
logger = logging.get_logger(__name__)
# add warnings here
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
" passing `position_ids`.",
FutureWarning,
)
if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
# first stage
if stage_manager.is_first_stage():
# check inputs and inputs embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
# other stage
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
if infer_state.is_context_stage:
past_key_values_length = 0
else:
past_key_values_length = infer_state.max_len_in_batch - 1
if seq_length != 1:
# prefill stage
infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
BatchInferState.init_block_loc(
infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
)
else:
infer_state.is_context_stage = False
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
if alloc_mem is not None:
infer_state.decode_is_contiguous = True
infer_state.decode_mem_index = alloc_mem[0]
infer_state.decode_mem_start = alloc_mem[1]
infer_state.decode_mem_end = alloc_mem[2]
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
else:
print(f" *** Encountered allocation non-contiguous")
print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}")
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
if attention_mask is None:
attention_mask = torch.ones((batch_size, infer_state.max_len_in_batch), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
# NOTE revise: we might want to store a single 1D alibi(length is #heads) in model,
# or store to BatchInferState to prevent re-calculating
# When we have multiple process group (e.g. dp together with tp), we need to pass the pg to here
tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
curr_tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
alibi = (
generate_alibi(self.num_heads * tp_size)
.contiguous()[curr_tp_rank * self.num_heads : (curr_tp_rank + 1) * self.num_heads]
.cuda()
)
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
infer_state.decode_layer_id = 0
start_idx, end_idx = stage_index[0], stage_index[1]
if past_key_values is None:
past_key_values = tuple([None] * (end_idx - start_idx + 1))
for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
block = self.h[idx]
outputs = block(
hidden_states,
layer_past=past_key_value,
attention_mask=causal_mask,
head_mask=head_mask[idx],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
infer_state=infer_state,
)
infer_state.decode_layer_id += 1
hidden_states = outputs[0]
if stage_manager.is_last_stage() or stage_manager.num_stages == 1:
hidden_states = self.ln_f(hidden_states)
# update indices
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
infer_state.max_len_in_batch += 1
# always return dict for imediate stage
return {"hidden_states": hidden_states}
@staticmethod
def bloom_block_forward(
self: BloomBlock,
hidden_states: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
infer_state: Optional[BatchInferState] = None,
):
# hidden_states: [batch_size, seq_length, hidden_size]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Layer norm post the self attention.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# Self attention.
attn_outputs = self.self_attention(
layernorm_output,
residual,
layer_past=layer_past,
attention_mask=attention_mask,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
infer_state=infer_state,
)
attention_output = attn_outputs[0]
outputs = attn_outputs[1:]
layernorm_output = self.post_attention_layernorm(attention_output)
# Get residual
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# MLP.
output = self.mlp(layernorm_output, residual)
if use_cache:
outputs = (output,) + outputs
else:
outputs = (output,) + outputs[1:]
return outputs # hidden_states, present, attentions
@staticmethod
def bloom_attention_forward(
self: BloomAttention,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
infer_state: Optional[BatchInferState] = None,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, q_length, H, D_HEAD = query_layer.shape
k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1
v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1
mem_manager = infer_state.cache_manager
layer_id = infer_state.decode_layer_id
if infer_state.is_context_stage:
# context process
max_input_len = q_length
b_start_loc = infer_state.start_loc
b_seq_len = infer_state.seq_len[:batch_size]
q = query_layer.reshape(-1, H, D_HEAD)
copy_kv_cache_to_dest(k, infer_state.context_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(v, infer_state.context_mem_index, mem_manager.value_buffer[layer_id])
# output = self.output[:batch_size*q_length, :, :]
output = torch.empty_like(q)
if HAS_LIGHTLLM_KERNEL:
lightllm_bloom_context_attention_fwd(q, k, v, output, alibi, b_start_loc, b_seq_len, max_input_len)
else:
bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi)
context_layer = output.view(batch_size, q_length, H * D_HEAD)
else:
# query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
# need shape: batch_size, H, D_HEAD (q_length == 1), input q shape : (batch_size, q_length(1), H, D_HEAD)
assert q_length == 1, "for non-context process, we only support q_length == 1"
q = query_layer.reshape(-1, H, D_HEAD)
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
cache_k = infer_state.cache_manager.key_buffer[layer_id][
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_v = infer_state.cache_manager.value_buffer[layer_id][
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_k.copy_(k)
cache_v.copy_(v)
else:
# if decode is not contiguous, use triton kernel to copy key and value cache
# k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
copy_kv_cache_to_dest(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id])
b_start_loc = infer_state.start_loc
b_loc = infer_state.block_loc
b_seq_len = infer_state.seq_len
output = torch.empty_like(q)
token_attention_fwd(
q,
mem_manager.key_buffer[layer_id],
mem_manager.value_buffer[layer_id],
output,
b_loc,
b_start_loc,
b_seq_len,
infer_state.max_len_in_batch,
alibi,
)
context_layer = output.view(batch_size, q_length, H * D_HEAD)
# NOTE: always set present as none for now, instead of returning past key value to the next decoding,
# we create the past key value pair from the cache manager
present = None
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
# dropout is not required here during inference
output_tensor = residual + output_tensor
outputs = (output_tensor, present)
assert output_attentions is False, "we do not support output_attentions at this time"
return outputs

View File

@ -167,7 +167,7 @@ class LlamaInferenceForwards:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
# If is first stage and after warmup, go throught lm_head first
# If is first stage and hidden_states is None, go throught lm_head first
if stage_manager.is_first_stage() and hidden_states is not None:
lm_logits = self.lm_head(hidden_states)
return {"logits": lm_logits}
@ -327,15 +327,6 @@ class LlamaInferenceForwards:
infer_state.seq_len += 1
infer_state.max_len_in_batch += 1
# if not return_dict:
# return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
# return BaseModelOutputWithPast(
# last_hidden_state=hidden_states,
# past_key_values=next_cache,
# hidden_states=all_hidden_states,
# attentions=all_self_attns,
# )
return {"hidden_states": hidden_states}
@staticmethod

View File

@ -1,3 +1,4 @@
from .bloom import BloomModelInferPolicy
from .llama import LlamaModelInferPolicy
__all__ = ["LlamaModelInferPolicy"]
__all__ = ["LlamaModelInferPolicy", "BloomModelInferPolicy"]

View File

@ -0,0 +1,127 @@
from functools import partial
from typing import List
import torch
from torch.nn import LayerNorm, Module
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
from ..modeling.bloom import BloomInferenceForwards
try:
from colossalai.kernel.triton import layer_norm
HAS_TRITON_NORM = True
except:
print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton")
HAS_TRITON_NORM = False
def get_triton_layernorm_forward():
if HAS_TRITON_NORM:
def _triton_layernorm_forward(self: LayerNorm, hidden_states: torch.Tensor):
return layer_norm(hidden_states, self.weight.data, self.bias, self.eps)
return _triton_layernorm_forward
else:
return None
class BloomModelInferPolicy(BloomForCausalLMPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
policy = super().module_policy()
if self.shard_config.inference_gptq:
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
policy[BloomBlock] = ModulePolicyDescription(
attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size
// self.shard_config.tensor_parallel_size,
"self_attention.split_size": self.model.config.hidden_size
// self.shard_config.tensor_parallel_size,
"self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 3},
),
SubModuleReplacementDescription(
suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1}
),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
),
],
)
# NOTE set inference mode to shard config
self.shard_config._infer()
# set as default, in inference we also use pipeline style forward, just setting stage as 1
self.set_pipeline_forward(
model_cls=BloomForCausalLM,
new_forward=partial(
BloomInferenceForwards.bloom_for_causal_lm_forward,
tp_group=self.shard_config.tensor_parallel_process_group,
),
policy=policy,
)
method_replacement = {"forward": BloomInferenceForwards.bloom_model_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel)
method_replacement = {"forward": BloomInferenceForwards.bloom_block_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock)
method_replacement = {"forward": BloomInferenceForwards.bloom_attention_forward}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=BloomAttention
)
if HAS_TRITON_NORM:
infer_method = get_triton_layernorm_forward()
method_replacement = {"forward": partial(infer_method)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LayerNorm
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == "BloomModel":
module = self.model
else:
module = self.model.transformer
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.word_embeddings)
held_layers.append(module.word_embeddings_layernorm)
held_layers.append(self.model.lm_head)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
return held_layers

View File

@ -112,11 +112,11 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
description=method_replacement, policy=policy, target_key=LlamaAttention
)
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=LlamaForCausalLM, new_forward=LlamaInferenceForwards.llama_causal_lm_forward, policy=policy
)
# set as default, in inference we also use pipeline style forward, just setting stage as 1
self.set_pipeline_forward(
model_cls=LlamaForCausalLM, new_forward=LlamaInferenceForwards.llama_causal_lm_forward, policy=policy
)
infer_forward = None
if HAS_TRITON_RMSNORM:
infer_forward = get_triton_rmsnorm_forward()
@ -135,8 +135,22 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == "LlamaModel":
module = self.model
else:
module = self.model.model
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
held_layers = []
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
held_layers.append(self.model.lm_head)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
return held_layers

View File

@ -0,0 +1,102 @@
import pytest
import torch
import torch.distributed as dist
import transformers
from packaging import version
import colossalai
from colossalai.inference import BloomModelInferPolicy, CaiInferEngine
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
def data_gen():
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
inputs = data_gen()
for k, v in inputs.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 16
inputs[k] = v.to("cuda").repeat(*new_shape)
def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
model = transformers.BloomForCausalLM(
transformers.BloomConfig(vocab_size=20000, hidden_size=512, n_head=4, n_layer=4)
)
engine = CaiInferEngine(
tp_size=tp_size,
pp_size=pp_size,
model=model,
model_policy=BloomModelInferPolicy(),
max_output_len=max_output_len,
micro_batch_size=micro_batch_size,
)
output = engine.inference(inputs)
if dist.get_rank() == 0:
assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"
@parameterize("tp_size", [1])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [2])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [2])
@parameterize("pp_size", [1])
@parameterize("max_output_len", [2])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
def check_pipeline_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_pipeline_inference_test()
def check_tp_pipeline_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_pipeline_inference_test()
def check_tp_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_inference_test()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
spawn(check_pipeline_inference, nprocs=2)
spawn(check_tp_pipeline_inference, nprocs=4)
spawn(check_tp_inference, nprocs=2)
if __name__ == "__main__":
test_pipeline_inference()