[infer] Infer/llama demo (#4503)

* add

* add infer example

* finish

* finish

* stash

* fix
Jianghai 2023-08-24 15:42:41 +08:00 committed by GitHub
parent d20dceb9a3
commit c427366024
7 changed files with 229 additions and 4 deletions


@@ -19,6 +19,7 @@ class LlamaPipelineForwards:
under pipeline setting.
'''
@staticmethod
def llama_model_forward(
self: LlamaModel,
input_ids: torch.LongTensor = None,
@@ -169,6 +170,7 @@ class LlamaPipelineForwards:
# always return dict for intermediate stage
return {'hidden_states': hidden_states}
@staticmethod
def llama_for_causal_lm_forward(
self: LlamaForCausalLM,
input_ids: torch.LongTensor = None,
@@ -276,6 +278,7 @@ class LlamaPipelineForwards:
hidden_states = outputs.get('hidden_states')
return {'hidden_states': hidden_states}
@staticmethod
def llama_for_sequence_classification_forward(
self: LlamaForSequenceClassification,
input_ids: torch.LongTensor = None,
@@ -388,6 +391,84 @@ class LlamaPipelineForwards:
return {'hidden_states': hidden_states}
class LlamaInferenceForwards:
"""
This class holds forwards for llama inference.
"""
@staticmethod
def llama_model_forward(
self: LlamaModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[
torch.LongTensor] = None,  # TODO: this can also be removed once sin/cos are cached in inferinfo
past_key_values: Optional[List[
torch.FloatTensor]] = None,  # TODO: may be removed once the memory cache manager is done
inputs_embeds: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
inferinfo=None,
):
# only keep the basic items
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device)
attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds,
past_key_values_length)
hidden_states = inputs_embeds
for idx, decoder_layer in enumerate(self.layers):
past_key_value = past_key_values[idx] if past_key_values is not None else None
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
)
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
if not return_dict:
return hidden_states
return BaseModelOutputWithPast(last_hidden_state=hidden_states,)
def get_llama_flash_attention_forward():
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
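
For reference, a minimal sketch of exercising the new inference forward directly (a hypothetical standalone check: the module path colossalai.shardformer.modeling.llama is assumed from the policy import below, and it presumes a transformers version whose LlamaModel still defines _prepare_decoder_attention_mask; in the PR the forward is only installed onto LlamaModel through the policy):

import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel

from colossalai.shardformer.modeling.llama import LlamaInferenceForwards

# Tiny, arbitrary config so the sketch runs quickly on CPU.
config = LlamaConfig(vocab_size=128, hidden_size=64, intermediate_size=128,
                     num_hidden_layers=2, num_attention_heads=4)
model = LlamaModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 8))

with torch.no_grad():
    # The staticmethod takes the module explicitly as `self`, mirroring how the
    # policy's method replacement binds it as LlamaModel.forward.
    out = LlamaInferenceForwards.llama_model_forward(model, input_ids=input_ids, return_dict=True)
print(out.last_hidden_state.shape)    # torch.Size([1, 8, 64])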


@@ -1,5 +1,6 @@
import importlib
from dataclasses import dataclass
from typing import Optional
import torch.nn as nn
@@ -130,6 +131,12 @@ _POLICY_LIST = {
PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"),
}
_INFER_POLICY_LIST = {
# LlaMa
"transformers.models.llama.modeling_llama.LlamaModel":
PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy")
}
def import_policy(policy_location: PolicyLocation) -> Policy:
"""
@@ -151,7 +158,7 @@ def _fullname(obj):
return module + '.' + klass.__qualname__
def get_autopolicy(model: nn.Module) -> Policy:
def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
r"""
Return the auto policy for the model
@@ -162,7 +169,10 @@ def get_autopolicy(model: nn.Module) -> Policy:
:class:`Policy`: The auto policy for the model
"""
full_name = _fullname(model)
policy_location = _POLICY_LIST.get(full_name, None)
if inference_only:
policy_location = _INFER_POLICY_LIST.get(full_name, None)
else:
policy_location = _POLICY_LIST.get(full_name, None)
if policy_location is None:
raise NotImplementedError(
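
A quick sketch of what the new inference_only switch changes in the lookup (names taken from this diff; the module path colossalai.shardformer.policies.auto_policy matches the import used in the tests below):

from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel

from colossalai.shardformer.policies.auto_policy import get_autopolicy

model = LlamaModel(LlamaConfig(vocab_size=128, hidden_size=64, intermediate_size=128,
                               num_hidden_layers=1, num_attention_heads=4))
# With inference_only=True the lookup goes through _INFER_POLICY_LIST,
# so a plain LlamaModel resolves to LlamaModelInferPolicy.
policy = get_autopolicy(model, inference_only=True)
print(type(policy).__name__)    # LlamaModelInferPolicy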


@@ -7,7 +7,7 @@ from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
from ..modeling.llama import LlamaInferenceForwards, LlamaPipelineForwards, get_llama_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
@@ -263,3 +263,21 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama for sequence classification model"""
return []
class LlamaModelInferPolicy(LlamaPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
policy = super().module_policy()
# configure default shard config for inference
self.shard_config._infer()
infer_forward = LlamaInferenceForwards.llama_model_forward
method_replacement = {'forward': partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
return policy


@@ -28,6 +28,7 @@ class ShardConfig:
enable_all_optimization: bool = False
enable_flash_attention: bool = False
enable_jit_fused: bool = False
inference_only: bool = False
# pipeline_parallel_size: int
# data_parallel_size: int
@@ -57,3 +58,9 @@ class ShardConfig:
self.enable_fused_normalization = True
self.enable_flash_attention = True
self.enable_jit_fused = True
def _infer(self):
"""
Set default params for inference.
"""
self.pipeline_stage_manager = None
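
Putting the pieces together, a sketch of the intended wiring (assumed to run inside a colossalai.launch-initialized process, as the test files below do):

from colossalai.shardformer import ShardConfig, ShardFormer

# An inference-only config: the sharder then resolves LlamaModelInferPolicy,
# which calls shard_config._infer() (dropping the pipeline stage manager) and
# replaces LlamaModel.forward with LlamaInferenceForwards.llama_model_forward.
shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model)    # `model` is a LlamaModel instance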


@@ -27,7 +27,8 @@ class ModelSharder(object):
def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
self.model = model
self.policy = get_autopolicy(self.model) if policy is None else policy
self.policy = get_autopolicy(self.model, shard_config.inference_only) if policy is None else policy
self.shard_config = shard_config
def shard(self) -> List[Dict[int, Tensor]]:


@@ -0,0 +1,53 @@
import copy
import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Adam, Optimizer
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
def build_model(
model_fn,
enable_fused_normalization=False,
enable_tensor_parallelism=False,
enable_flash_attention=False,
enable_jit_fused=False,
):
# create new model
org_model = model_fn()
# shard model
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused,
inference_only=True)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy)
return org_model.cuda(), sharded_model.cuda()
def run_infer(original_model, sharded_model, data_gen_fn, output_transform_fn):
# prepare input
data = data_gen_fn()
data = {k: v.cuda() for k, v in data.items()}
# run forward
org_output = original_model(**data)
org_output = output_transform_fn(org_output)
shard_output = sharded_model(**data)
shard_output = output_transform_fn(shard_output)
return org_output, shard_output


@@ -0,0 +1,55 @@
import os
import pytest
import torch
from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_infer._utils import build_model, run_infer
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
def check_infer(model_fn, data_gen_fn, output_transform_fn, test_config):
org_model, sharded_model = build_model(model_fn, **test_config)
org_output, infer_output = run_infer(org_model, sharded_model, data_gen_fn, output_transform_fn)
print('original output', org_output[0])
print('infer output', infer_output[0])
@parameterize('test_config', [{
'enable_flash_attention': False,
}])
def run_llama_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name != "transformers_llama":
continue
check_infer(model_fn, data_gen_fn, output_transform_fn, test_config)
torch.cuda.empty_cache()
def check_llama(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_llama_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
spawn(check_llama, 1)
if __name__ == "__main__":
test_llama()