
[infer] Add Bloom inference policy and replaced methods (#4512)

* add bloom inference methods and policy

* enable pass BatchInferState from model forward

* revise bloom infer layers/policies

* add engine for inference (draft)

* add test for bloom infer

* fix bloom infer policy and flow

* revise bloom test

* fix bloom file path

* remove unused codes

* fix bloom modeling

* fix dir typo

* fix trivial

* fix policy

* clean pr

* trivial fix
branch: pull/4552/head
Yuanheng Zhao authored 1 year ago · committed by GitHub
commit 17cfa57140
  1. colossalai/inference/tensor_parallel/__init__.py (6 changed lines)
  2. colossalai/inference/tensor_parallel/engine.py (9 changed lines)
  3. colossalai/inference/tensor_parallel/modeling/__init__.py (3 changed lines)
  4. colossalai/inference/tensor_parallel/modeling/bloom.py (559 changed lines)
  5. colossalai/inference/tensor_parallel/policies/__init__.py (4 changed lines)
  6. colossalai/inference/tensor_parallel/policies/bloom.py (44 changed lines)
  7. colossalai/inference/tensor_parallel/policies/llama.py (17 changed lines)
  8. colossalai/inference/tensor_parallel/pollcies/__init__.py (3 changed lines)
  9. colossalai/shardformer/policies/auto_policy.py (8 changed lines)
  10. tests/test_infer/test_bloom_infer.py (60 changed lines)

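Taken together, these files route BLOOM models through the existing tensor-parallel inference path: new inference forwards and a policy for BLOOM, plus registry entries so the sharder picks them up automatically. A minimal usage sketch, condensed from the new tests/test_infer/test_bloom_infer.py at the end of this diff (the checkpoint name is a placeholder, and a 2-rank NCCL group is assumed to have been initialized via colossalai.launch, as the test does through spawn):

import torch
from transformers import AutoTokenizer, BloomForCausalLM

from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig, ShardFormer

# placeholder BLOOM checkpoint; the test uses a local phoenix-inst-chat-7b path
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
tokenizer.pad_token = tokenizer.eos_token
model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m", pad_token_id=tokenizer.eos_token_id)
model = model.half().to(torch.cuda.current_device())

# inference_only=True makes auto_policy resolve BloomModelInferPolicy (see auto_policy.py below)
shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
shardformer = ShardFormer(shard_config=shard_config)

engine = TPInferEngine(model, 4, 16, 32)    # max batch size, max input len, max output len
engine.prepare_with_shard_config(shard_config=shard_config)
engine.shard_model_by(shardformer)

inputs = tokenizer.batch_encode_plus(["Introduce some landmarks in Beijing"], return_tensors='pt')
outputs = engine.generate(inputs, dict(do_sample=False))
print(tokenizer.decode(outputs[0]))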
colossalai/inference/tensor_parallel/__init__.py (6 changed lines)

@@ -1,6 +1,4 @@
- from .modeling.llama import LlamaInferenceForwards
- from .pollcies.llama import LlamaModelInferPolicy
  from .engine import TPInferEngine
  from .kvcache_manager import MemoryManager
- __all__ = ['LlamaInferenceForwards', 'LlamaModelInferPolicy', 'MemoryManager', 'TPInferEngine']
+ __all__ = ['MemoryManager', 'TPInferEngine']

colossalai/inference/tensor_parallel/engine.py (9 changed lines)

@@ -141,7 +141,6 @@ class TPInferEngine:
  outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False)
- print(f"outputs.shape {outputs.shape}")
  return outputs

  def prepare_batch_state(self, inputs) -> BatchInferState:
@@ -193,11 +192,7 @@ class TPInferEngine:
  start_index += curr_seq_len
  max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
- print(" 666 ", max_len_in_batch)
- block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len),
-                         dtype=torch.long,
-                         device='cuda')
+ block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda')
  batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
  batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device
  batch_infer_state.start_loc = seq_start_indexes.to('cuda')
@@ -251,4 +246,4 @@ class TPInferEngine:
  # => put information already recorded in batchinferstate and pass it to model forward
  # => clear records in engine
  def add_request():
-     raise NotImplementedError()
+     raise NotImplementedError()

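The prepare_batch_state changes above keep the per-batch bookkeeping that the inference forwards rely on: per-sequence lengths, their start offsets in the flattened token dimension, the maximum length in the batch, and a (batch, max_input_len + max_output_len) block_loc table of KV-cache slot indices. A small sketch of that bookkeeping with toy lengths (variable names mirror the diff; the cumsum is just an equivalent way to write the running start_index accumulation):

import torch

seq_lengths = torch.tensor([5, 3, 7], dtype=torch.int32)    # toy per-sequence prompt lengths

# start offset of each sequence in the flattened token dimension
seq_start_indexes = torch.cumsum(seq_lengths, dim=0) - seq_lengths
max_len_in_batch = int(seq_lengths.max())

# one KV-cache slot index per (sequence, position); filled in by the model forward later
max_input_len, max_output_len = 16, 32
block_loc = torch.empty((len(seq_lengths), max_input_len + max_output_len), dtype=torch.long)

print(seq_start_indexes.tolist(), max_len_in_batch)    # [0, 5, 8] 7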
colossalai/inference/tensor_parallel/modeling/__init__.py (3 changed lines)

@@ -1,3 +1,4 @@
+ from .bloom import BloomInferenceForwards
  from .llama import LlamaInferenceForwards
- __all__ = ['LlamaInferenceForwards']
+ __all__ = ['BloomInferenceForwards', 'LlamaInferenceForwards']

colossalai/inference/tensor_parallel/modeling/bloom.py (559 changed lines, new file)

@@ -0,0 +1,559 @@
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import functional as F
from transformers.models.bloom.modeling_bloom import (
BaseModelOutputWithPastAndCrossAttentions,
BloomAttention,
BloomBlock,
BloomForCausalLM,
BloomModel,
CausalLMOutputWithCrossAttentions,
)
from transformers.utils import logging
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
def generate_alibi(n_head, dtype=torch.float16):
"""
This method is originally the `build_alibi_tensor` function
in `transformers/models/bloom/modeling_bloom.py`
of the huggingface/transformers GitHub repository.
Copyright 2023 ModelTC Team
Copyright 2022 HuggingFace Inc. team and BigScience workshop
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2**(-(2**-(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
return (get_slopes_power_of_2(closest_power_of_2) +
get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
slopes = torch.Tensor(get_slopes(n_head))
head_alibi = slopes.to(dtype)
return head_alibi # 1 * num_heads
def generate_alibi_2(n_head, dtype=torch.float16):
def get_slopes_power_of_2(n):
start = 2**(-(2**-(math.log2(n) - 3)))
return [start * start**i for i in range(n)]
def get_slopes(n):
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2)
slopes_double = get_slopes(2 * closest_power_of_2)
slopes_combined = slopes_power_of_2 + slopes_double[0::2][:n - closest_power_of_2]
return slopes_combined
slopes = torch.tensor(get_slopes(n_head), dtype=dtype)
return slopes
class BloomInferenceForwards:
"""
This class serves as a micro library for bloom inference forwards
"""
@staticmethod
def bloom_model_forward(
self: BloomModel,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
infer_state: Optional[BatchInferState] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
logger = logging.get_logger(__name__)
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
" passing `position_ids`.",
FutureWarning,
)
if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
# still need to keep past_key_values to fit original forward flow
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
# NOTE determine if BatchInferState is passed in via arg
# if not, get the attr bound to the model
# We might want to remove setattr later
if infer_state is None:
assert hasattr(self, 'infer_state')
infer_state = self.infer_state
# Compute alibi tensor: check build_alibi_tensor documentation
seq_length_with_past = seq_length
past_key_values_length = 0
# if self.cache_manager.past_key_values_length > 0:
if infer_state.cache_manager.past_key_values_length > 0:
# update the past key values length in cache manager,
# TODO use BatchInferState.past_key_values_length instead of the one in the cache manager
past_key_values_length = infer_state.cache_manager.past_key_values_length
seq_length_with_past = seq_length_with_past + past_key_values_length
# infer_state.cache_manager = self.cache_manager
if use_cache and seq_length != 1:
# prefill stage
infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
BatchInferState.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length,
infer_state.context_mem_index)
else:
infer_state.is_context_stage = False
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
if alloc_mem is not None:
infer_state.decode_is_contiguous = True
infer_state.decode_mem_index = alloc_mem[0]
infer_state.decode_mem_start = alloc_mem[1]
infer_state.decode_mem_end = alloc_mem[2]
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
else:
print(f" *** Encountered allocation non-contiguous")
print(
f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
)
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
# infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
# infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
# TODO revise: we might want to store a single 1D alibi(length is #heads) in model,
# or store to BatchInferState to prevent re-calculating
# When we have multiple process group (e.g. dp together with tp), we need to pass the pg to here
# alibi = generate_alibi(self.num_heads).contiguous().cuda()
tp_size = dist.get_world_size()
curr_tp_rank = dist.get_rank()
alibi = generate_alibi(self.num_heads * tp_size).contiguous()[curr_tp_rank * self.num_heads:(curr_tp_rank + 1) *
self.num_heads].cuda()
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
# FIXME: currently our KV cache manager does not handle this condition
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
alibi,
causal_mask,
layer_past,
head_mask[i],
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
infer_state=infer_state,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# update indices of kv cache block
# TODO: might want to remove this part; instead, it would be better to pass the BatchInferState from the model forward
# and update this information in engine.generate after the model forward is called
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
infer_state.decode_layer_id = 0
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents, # should always be (None, None, ..., None)
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
@staticmethod
def bloom_for_causal_lm_forward(self: BloomForCausalLM,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
infer_state: Optional[BatchInferState] = None,
**deprecated_arguments):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
logger = logging.get_logger(__name__)
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
" passing `position_ids`.",
FutureWarning,
)
if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = BloomInferenceForwards.bloom_model_forward(self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
infer_state=infer_state)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size),
shift_labels.view(batch_size * seq_length))
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@staticmethod
def bloom_for_causal_lm_prepare_inputs_for_generation(
self: BloomForCausalLM,
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
# only last token for input_ids if past is not None
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
# NOTE we won't use past key values here
# the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
# if past_key_values[0][0].shape[0] == input_ids.shape[0]:
# past_key_values = self._convert_to_bloom_cache(past_key_values)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update({
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
})
return model_inputs
# replace decoder layer forward:
# used to replace BloomBlock.forward
@staticmethod
def bloom_block_forward(
self: BloomBlock,
hidden_states: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
infer_state: Optional[BatchInferState] = None,
):
# hidden_states: [batch_size, seq_length, hidden_size]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Layer norm post the self attention.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# Self attention.
attn_outputs = self.self_attention(
layernorm_output,
residual,
layer_past=layer_past,
attention_mask=attention_mask,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
infer_state=infer_state,
)
attention_output = attn_outputs[0]
outputs = attn_outputs[1:]
layernorm_output = self.post_attention_layernorm(attention_output)
# Get residual
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# MLP.
output = self.mlp(layernorm_output, residual)
if use_cache:
outputs = (output,) + outputs
else:
outputs = (output,) + outputs[1:]
return outputs # hidden_states, present, attentions
# replace attention forward:
# used to replace BloomAttention.forward
@staticmethod
def bloom_attention_forward(
self: BloomAttention,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
infer_state: Optional[BatchInferState] = None,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, q_length, H, D_HEAD = query_layer.shape
k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_length == 1
v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_length == 1
mem_manager = infer_state.cache_manager
layer_id = infer_state.decode_layer_id
if infer_state.is_context_stage:
# context process
max_input_len = q_length
b_start_loc = infer_state.start_loc
b_seq_len = infer_state.seq_len[:batch_size]
q = query_layer.reshape(-1, H, D_HEAD)
copy_kv_cache_to_dest(k, infer_state.context_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(v, infer_state.context_mem_index, mem_manager.value_buffer[layer_id])
# output = self.output[:batch_size*q_length, :, :]
output = torch.empty_like(q)
bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi)
context_layer = output.view(batch_size, q_length, H * D_HEAD)
# record the length of past key values cache when entering the first attention layer in bloom block,
# since we won't return past_key_value_cache right now
if layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length = q_length # seq_len
else:
# query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
# need shape: batch_size, H, D_HEAD (q_length == 1), input q shape : (batch_size, q_length(1), H, D_HEAD)
assert q_length == 1, "for non-context process, we only support q_length == 1"
q = query_layer.reshape(-1, H, D_HEAD)
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
cache_k = infer_state.cache_manager.key_buffer[layer_id][
infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
cache_v = infer_state.cache_manager.value_buffer[layer_id][
infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
cache_k.copy_(k)
cache_v.copy_(v)
else:
# if decode is not contiguous, use triton kernel to copy key and value cache
# k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
copy_kv_cache_to_dest(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id])
b_start_loc = infer_state.start_loc[:batch_size]
b_loc = infer_state.block_loc[:batch_size, :]
b_seq_len = infer_state.seq_len[:batch_size]
max_len_in_batch = mem_manager.past_key_values_length + q_length
output = torch.empty_like(q)
token_attention_fwd(q, mem_manager.key_buffer[layer_id], mem_manager.value_buffer[layer_id], output, b_loc,
b_start_loc, b_seq_len, max_len_in_batch, alibi)
context_layer = output.view(batch_size, q_length, H * D_HEAD)
if layer_id == 0: # once per model.forward
assert infer_state.cache_manager.past_key_values_length != 0
infer_state.cache_manager.past_key_values_length += q_length # += 1
# update layer id
infer_state.decode_layer_id += 1
# NOTE: always set present as none for now, instead of returning past key value to the next decoding,
# we create the past key value pair from the cache manager
present = None
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices):int((i + 1) * slices)],
self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
# dropout is not required here during inference
output_tensor = residual + output_tensor
outputs = (output_tensor, present)
assert output_attentions is False, "we do not support output_attentions at this time"
return outputs

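The forwards above keep ALiBi as one slope per attention head and let each tensor-parallel rank slice out the heads it owns (see the generate_alibi call in bloom_model_forward). A small sketch of that slicing, runnable without any distributed setup; the head count, world size, and rank below are toy values chosen for illustration:

from colossalai.inference.tensor_parallel.modeling.bloom import generate_alibi

num_heads_per_rank, tp_size, rank = 4, 2, 1    # toy sizes, not taken from the PR
# build slopes for the *global* head count, then keep this rank's contiguous slice,
# mirroring the alibi computation in bloom_model_forward
alibi = generate_alibi(num_heads_per_rank * tp_size)
alibi_local = alibi[rank * num_heads_per_rank:(rank + 1) * num_heads_per_rank]
print(alibi_local.shape)    # torch.Size([4]) -- one slope per local head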
colossalai/inference/tensor_parallel/policies/__init__.py (4 changed lines, new file)

@@ -0,0 +1,4 @@
from .bloom import BloomModelInferPolicy
from .llama import LlamaModelInferPolicy
__all__ = ['BloomModelInferPolicy', 'LlamaModelInferPolicy']

colossalai/inference/tensor_parallel/policies/bloom.py (44 changed lines, new file)

@@ -0,0 +1,44 @@
from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
from ..modeling.bloom import BloomInferenceForwards
class BloomModelInferPolicy(BloomForCausalLMPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
policy = super().module_policy()
# NOTE set inference mode to shard config
self.shard_config._infer()
if self.shard_config.enable_tensor_parallelism:
method_replacement = {
'forward':
BloomInferenceForwards.bloom_for_causal_lm_forward,
'prepare_inputs_for_generation':
BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation
}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=BloomForCausalLM)
method_replacement = {'forward': BloomInferenceForwards.bloom_model_forward}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=BloomModel)
method_replacement = {'forward': BloomInferenceForwards.bloom_block_forward}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=BloomBlock)
method_replacement = {'forward': BloomInferenceForwards.bloom_attention_forward}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=BloomAttention)
return policy

colossalai/inference/tensor_parallel/pollcies/llama.py → colossalai/inference/tensor_parallel/policies/llama.py (17 changed lines)

@@ -2,7 +2,8 @@ from functools import partial
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
from ..modeling.llama import LlamaInferenceForwards
from ..modeling.llama import LlamaInferenceForwards
class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
@@ -23,13 +24,17 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
  infer_forward = LlamaInferenceForwards.llama_model_forward
  method_replacement = {'forward': partial(infer_forward)}
  self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)

  infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward
  method_replacement = {'forward': partial(infer_forward)}
- self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaDecoderLayer)
+ self.append_or_create_method_replacement(description=method_replacement,
+                                          policy=policy,
+                                          target_key=LlamaDecoderLayer)

  infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward
  method_replacement = {'forward': partial(infer_forward)}
- self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaAttention)
+ self.append_or_create_method_replacement(description=method_replacement,
+                                          policy=policy,
+                                          target_key=LlamaAttention)

- return policy
+ return policy

colossalai/inference/tensor_parallel/pollcies/__init__.py (3 changed lines, file deleted)

@@ -1,3 +0,0 @@
- from .llama import LlamaModelInferPolicy
-
- __all__ = ['LlamaModelInferPolicy']

colossalai/shardformer/policies/auto_policy.py (8 changed lines)

@@ -137,6 +137,11 @@ _INFER_POLICY_LIST = {
  PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"),
  "transformers.models.llama.modeling_llama.LlamaForCausalLM":
  PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"),
+ # Bloom
+ "transformers.models.bloom.modeling_bloom.BloomModel":
+ PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"),
+ "transformers.models.bloom.modeling_bloom.BloomForCausalLM":
+ PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"),
  }
@@ -144,9 +149,8 @@ def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool
  """
  Dynamically import a Policy class based on the policy location.
  """
  if inference_only:
- module_name = f"colossalai.inference.tensor_parallel.pollcies.{policy_location.file_name}"
+ module_name = f"colossalai.inference.tensor_parallel.policies.{policy_location.file_name}"
  else:
  module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
  module = importlib.import_module(module_name)

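With these registry entries, inference-mode auto policy resolution now maps both BloomModel and BloomForCausalLM to BloomModelInferPolicy, and inference_only=True routes the import to the renamed colossalai.inference.tensor_parallel.policies package. A minimal sketch of the lookup, using only the names visible in this diff:

from colossalai.shardformer.policies.auto_policy import _INFER_POLICY_LIST

# the fully-qualified transformers class path is the registry key, as added above
location = _INFER_POLICY_LIST["transformers.models.bloom.modeling_bloom.BloomForCausalLM"]
print(location.file_name, location.class_name)    # -> bloom BloomModelInferPolicy
# import_policy(location, inference_only=True) then builds the module path
# colossalai.inference.tensor_parallel.policies.bloom and imports the policy from it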
tests/test_infer/test_bloom_infer.py (60 changed lines, new file)

@@ -0,0 +1,60 @@
import pytest
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer, BloomForCausalLM
import colossalai
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
TP_SIZE = 2
MAX_BATCH_SIZE = 4
MAX_INPUT_LEN = 16
MAX_OUTPUT_LEN = 32
def run():
model_path = "/data3/data/model_eval_for_commerical_use/phoenix-inst-chat-7b"
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token
text = "Introduce some landmarks in Beijing"
input_ids = tokenizer.batch_encode_plus([text], return_tensors='pt')
model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
model = model.half()
model.to(torch.cuda.current_device())
shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
shardformer = ShardFormer(shard_config=shard_config)
infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.prepare_with_shard_config(shard_config=shard_config)
infer_engine.shard_model_by(shardformer)
generate_kwargs = dict(do_sample=False)
outputs = infer_engine.generate(input_ids, generate_kwargs)
if not dist.is_initialized() or dist.get_rank() == 0:
output_text = tokenizer.decode(outputs[0])
print(output_text)
def check_engine(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_engine_infer():
spawn(check_engine, TP_SIZE)
if __name__ == '__main__':
test_engine_infer()