mirror of https://github.com/hpcaitech/ColossalAI
[infer] Infer/llama demo (#4503)
* add
* add infer example
* finish
* finish
* stash
* fix

pull/4509/head
parent d20dceb9a3
commit c427366024
@@ -19,6 +19,7 @@ class LlamaPipelineForwards:
    under pipeline setting.
    '''

    @staticmethod
    def llama_model_forward(
        self: LlamaModel,
        input_ids: torch.LongTensor = None,
@@ -169,6 +170,7 @@ class LlamaPipelineForwards:
        # always return dict for intermediate stage
        return {'hidden_states': hidden_states}

    @staticmethod
    def llama_for_causal_lm_forward(
        self: LlamaForCausalLM,
        input_ids: torch.LongTensor = None,
@@ -276,6 +278,7 @@ class LlamaPipelineForwards:
        hidden_states = outputs.get('hidden_states')
        return {'hidden_states': hidden_states}

    @staticmethod
    def llama_for_sequence_classification_forward(
        self: LlamaForSequenceClassification,
        input_ids: torch.LongTensor = None,
@@ -388,6 +391,84 @@ class LlamaPipelineForwards:
        return {'hidden_states': hidden_states}


class LlamaInferenceForwards:
    """
    This class holds forwards for llama inference.
    """

    @staticmethod
    def llama_model_forward(
        self: LlamaModel,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[
            torch.LongTensor] = None,    # TODO: this can also be removed if we got sin,cos cached in inferinfo
        past_key_values: Optional[List[
            torch.FloatTensor]] = None,    # TODO: maybe removed after memory cache manager is done.
        inputs_embeds: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        inferinfo=None,
    ):
        # only keep the basic items
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(past_key_values_length,
                                        seq_length + past_key_values_length,
                                        dtype=torch.long,
                                        device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past),
                                        dtype=torch.bool,
                                        device=inputs_embeds.device)
        attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds,
                                                              past_key_values_length)

        hidden_states = inputs_embeds

        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx] if past_key_values is not None else None
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)

        if not return_dict:
            return hidden_states
        return BaseModelOutputWithPast(last_hidden_state=hidden_states,)


def get_llama_flash_attention_forward():

    from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
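Not part of the diff: the inference forward above defaults position_ids so that new tokens are numbered after any cached ones, and defaults attention_mask to all-ones over the combined past-plus-current length. A minimal standalone sketch of that bookkeeping, using hypothetical sizes:

import torch

# hypothetical decode step: 8 tokens already cached, 1 new token per sequence
batch_size, seq_length, past_key_values_length = 2, 1, 8
device = torch.device('cpu')

# mirrors the position_ids default in llama_model_forward above
position_ids = torch.arange(past_key_values_length,
                            seq_length + past_key_values_length,
                            dtype=torch.long,
                            device=device).unsqueeze(0).view(-1, seq_length)

# mirrors the attention_mask default: ones over past + current positions
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length),
                            dtype=torch.bool,
                            device=device)

print(position_ids)            # tensor([[8]]), the next position after 8 cached tokens
print(attention_mask.shape)    # torch.Size([2, 9])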
@@ -1,5 +1,6 @@
import importlib
from dataclasses import dataclass
from typing import Optional

import torch.nn as nn
@@ -130,6 +131,12 @@ _POLICY_LIST = {
        PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"),
}

_INFER_POLICY_LIST = {
    # LlaMa
    "transformers.models.llama.modeling_llama.LlamaModel":
        PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy")
}


def import_policy(policy_location: PolicyLocation) -> Policy:
    """
@@ -151,7 +158,7 @@ def _fullname(obj):
    return module + '.' + klass.__qualname__


-def get_autopolicy(model: nn.Module) -> Policy:
+def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
    r"""
    Return the auto policy for the model
@@ -162,7 +169,10 @@ def get_autopolicy(model: nn.Module) -> Policy:
        :class:`Policy`: The auto policy for the model
    """
    full_name = _fullname(model)
-    policy_location = _POLICY_LIST.get(full_name, None)
+    if inference_only:
+        policy_location = _INFER_POLICY_LIST.get(full_name, None)
+    else:
+        policy_location = _POLICY_LIST.get(full_name, None)

    if policy_location is None:
        raise NotImplementedError(
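A minimal sketch (not part of this diff) of how the inference-only dispatch above is expected to behave. It assumes transformers is installed, that get_autopolicy is importable from colossalai.shardformer.policies.auto_policy (the module the test utilities below also import from), and uses a deliberately tiny, hypothetical Llama configuration:

from transformers import LlamaConfig, LlamaModel

from colossalai.shardformer.policies.auto_policy import get_autopolicy

# tiny config so the example stays cheap; the values are illustrative only
config = LlamaConfig(num_hidden_layers=2, hidden_size=128, intermediate_size=256, num_attention_heads=4)
model = LlamaModel(config)

train_policy = get_autopolicy(model)                        # resolved through _POLICY_LIST
infer_policy = get_autopolicy(model, inference_only=True)   # resolved through _INFER_POLICY_LIST
print(type(train_policy).__name__, type(infer_policy).__name__)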
@@ -7,7 +7,7 @@ from torch.nn import Module

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

-from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
+from ..modeling.llama import LlamaInferenceForwards, LlamaPipelineForwards, get_llama_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
@@ -263,3 +263,21 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in llama for sequence classification model"""
        return []


class LlamaModelInferPolicy(LlamaPolicy):

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
        policy = super().module_policy()
        # configure default shard config for inference
        self.shard_config._infer()

        infer_forward = LlamaInferenceForwards.llama_model_forward
        method_replacement = {'forward': partial(infer_forward)}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)

        return policy
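The policy above swaps LlamaModel.forward for LlamaInferenceForwards.llama_model_forward through the method replacement. A hedged illustration of what that substitution amounts to, not the ShardFormer mechanism itself; it assumes the import path colossalai.shardformer.modeling.llama (inferred from the relative "..modeling.llama" import) and a transformers release where LlamaModel still exposes _prepare_decoder_attention_mask, as the forward requires:

from functools import partial

import torch
from transformers import LlamaConfig, LlamaModel

# import path inferred from the policy file's relative import; treat it as an assumption
from colossalai.shardformer.modeling.llama import LlamaInferenceForwards

model = LlamaModel(LlamaConfig(num_hidden_layers=2, hidden_size=128,
                               intermediate_size=256, num_attention_heads=4))

# bind the trimmed-down inference forward to this instance, as the method replacement does
infer_forward = partial(LlamaInferenceForwards.llama_model_forward, model)

input_ids = torch.randint(0, model.config.vocab_size, (1, 8))
hidden_states = infer_forward(input_ids=input_ids, return_dict=False)
print(hidden_states.shape)    # torch.Size([1, 8, 128])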
@@ -28,6 +28,7 @@ class ShardConfig:
    enable_all_optimization: bool = False
    enable_flash_attention: bool = False
    enable_jit_fused: bool = False
    inference_only: bool = False

    # pipeline_parallel_size: int
    # data_parallel_size: int
@@ -57,3 +58,9 @@ class ShardConfig:
        self.enable_fused_normalization = True
        self.enable_flash_attention = True
        self.enable_jit_fused = True

    def _infer(self):
        """
        Set default params for inference.
        """
        self.pipeline_stage_manager = None
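LlamaModelInferPolicy.module_policy() (shown earlier) calls shard_config._infer(), which drops the pipeline stage manager so inference runs the whole model as a single stage. A small sketch of that configuration path, not part of the diff and assuming a ShardConfig can be constructed outside a tensor-parallel process group when enable_tensor_parallelism is False:

from colossalai.shardformer import ShardConfig

# inference-only config, mirroring the one built in the test utilities below
shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
shard_config._infer()
assert shard_config.pipeline_stage_manager is None    # single-stage execution for inference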
@@ -27,7 +27,8 @@ class ModelSharder(object):

    def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
        self.model = model
-        self.policy = get_autopolicy(self.model) if policy is None else policy
+        self.policy = get_autopolicy(self.model, shard_config.inference_only) if policy is None else policy
+        print(self.policy)
        self.shard_config = shard_config

    def shard(self) -> List[Dict[int, Tensor]]:
@@ -0,0 +1,53 @@
import copy

import torch
import torch.distributed as dist
from torch import Tensor
from torch import distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Adam, Optimizer

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor


def build_model(
    model_fn,
    enable_fused_normalization=False,
    enable_tensor_parallelism=False,
    enable_flash_attention=False,
    enable_jit_fused=False,
):
    # create new model
    org_model = model_fn()

    # shard model
    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                               enable_tensor_parallelism=enable_tensor_parallelism,
                               enable_flash_attention=enable_flash_attention,
                               enable_jit_fused=enable_jit_fused,
                               inference_only=True)
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model_copy)
    return org_model.cuda(), sharded_model.cuda()


def run_infer(original_model, sharded_model, data_gen_fn, output_transform_fn):
    # prepare input
    data = data_gen_fn()
    data = {k: v.cuda() for k, v in data.items()}
    # run forward
    org_output = original_model(**data)
    org_output = output_transform_fn(org_output)

    shard_output = sharded_model(**data)
    shard_output = output_transform_fn(shard_output)

    return org_output, shard_output
@@ -0,0 +1,55 @@
import os

import pytest
import torch
from torch import distributed as dist

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_infer._utils import build_model, run_infer

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'


def check_infer(model_fn, data_gen_fn, output_transform_fn, test_config):
    org_model, sharded_model = build_model(model_fn, **test_config)

    org_output, infer_output = run_infer(org_model, sharded_model, data_gen_fn, output_transform_fn)

    print('original output', org_output[0])
    print('infer output', infer_output[0])


@parameterize('test_config', [{
    'enable_flash_attention': False,
}])
def run_llama_test(test_config):

    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        if name != "transformers_llama":
            continue
        check_infer(model_fn, data_gen_fn, output_transform_fn, test_config)
    torch.cuda.empty_cache()


def check_llama(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_llama_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
    spawn(check_llama, 1)


if __name__ == "__main__":
    test_llama()
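check_infer above only prints the two outputs. A hedged sketch of how that comparison could be tightened into an assertion, assuming the transformed outputs index like Hugging Face ModelOutput objects whose first element is a hidden-state tensor (as the existing org_output[0] / infer_output[0] prints suggest):

import torch


def assert_infer_close(org_output, infer_output, atol=1e-4, rtol=1e-4):
    # compare the leading tensors, the same elements check_infer prints
    org_hidden = org_output[0].float()
    infer_hidden = infer_output[0].float()
    assert torch.allclose(org_hidden, infer_hidden, atol=atol, rtol=rtol), \
        "sharded inference output diverged from the original model"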