
[Inference] Fix flash-attn import and add model test (#5794)

* Fix torch int32 dtype

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Fix flash-attn import

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Add generalized model test

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Remove exposed path to model

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Add default value for use_flash_attn

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Rename model test

Signed-off-by: char-1ee <xingjianli59@gmail.com>

---------

Signed-off-by: char-1ee <xingjianli59@gmail.com>
Li Xingjian committed 8554585a5f
Changed files:

1. colossalai/inference/modeling/backends/attention_backend.py (6 lines changed)
2. colossalai/inference/modeling/models/nopadding_baichuan.py (2 lines changed)
3. colossalai/inference/modeling/models/nopadding_llama.py (4 lines changed)
4. colossalai/inference/utils.py (2 lines changed)
5. colossalai/shardformer/layer/attn.py (2 lines changed)
6. tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py (2 lines changed)
7. tests/test_infer/test_models/test_custom_model.py (161 lines added)

colossalai/inference/modeling/backends/attention_backend.py (6 lines changed)

@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 import torch
-from flash_attn import flash_attn_varlen_func
 from colossalai.inference.config import ModelShardInferenceConfig
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
@@ -44,7 +43,7 @@ class CudaAttentionBackend(AttentionBackend):
     it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding.
     """
-    def __init__(self, use_flash_attn: bool):
+    def __init__(self, use_flash_attn: bool = False):
         super().__init__()
         self.inference_ops = InferenceOpsLoader().load()
         self.use_flash_attn = use_flash_attn
@@ -52,6 +51,9 @@ class CudaAttentionBackend(AttentionBackend):
     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
         if self.use_flash_attn:
             token_nums = kwargs.get("token_nums", -1)
+            from flash_attn import flash_attn_varlen_func
             attn_output = flash_attn_varlen_func(
                 attn_metadata.query_states,
                 attn_metadata.key_states,
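
For readers outside the diff context, a minimal self-contained sketch of the pattern these hunks apply, using illustrative names rather than the actual ColossalAI classes: the module-level flash-attn import is dropped, the flag defaults to off, and the import happens only on the code path that needs it.

import torch


class DemoAttentionBackend:
    """Illustrative stand-in for the CUDA attention backend (not the real class)."""

    def __init__(self, use_flash_attn: bool = False):
        # Defaulting to False keeps existing callers working even when
        # flash-attn is not installed.
        self.use_flash_attn = use_flash_attn

    def prefill(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        if self.use_flash_attn:
            # Deferred import: only evaluated when the flash path is taken,
            # so importing this module never raises ImportError on machines
            # without flash-attn.
            from flash_attn import flash_attn_varlen_func  # noqa: F401

            # ... the real backend calls flash_attn_varlen_func(...) here ...
        # Fallback path (the Triton/CUDA kernels in ColossalAI) would run here.
        return torch.zeros_like(q)  # placeholder output for the sketch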

colossalai/inference/modeling/models/nopadding_baichuan.py (2 lines changed)

@@ -200,8 +200,6 @@ class NopadBaichuanAttention(ParallelModule):
            self.pre_attention_backend.decode(
                attn_metadata,
                cos=cos_sin[0],
                sin=cos_sin[1],
                q_len=q_len,
            )
            attn_output = self.attention_backend.decode(

colossalai/inference/modeling/models/nopadding_llama.py (4 lines changed)

@@ -114,7 +114,7 @@ def llama_model_forward(
     elif use_cuda_kernel:
         if can_use_flash_attn2(inputmetadata.dtype):
-            cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
+            cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
         hidden_dim = self._cos_cached.size(-1)
         total_length = hidden_states.size(0)
@@ -265,7 +265,7 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
         mlp_dproj: ParallelModule = None,
         process_group: ProcessGroup = None,
     ):
-        """A Unified Layer for
+        """Replacement of LlamaMLP layer.
         Args:
             config (LlamaConfig): Holding the Llama model config.

colossalai/inference/utils.py (2 lines changed)

@@ -152,6 +152,8 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool:
        return False
    try:
        from flash_attn import flash_attn_varlen_func  # noqa
        return True
    except ImportError:
        logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
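
Callers gate the flash-attn path on this check instead of importing flash-attn themselves, as the `llama_model_forward` hunk above shows. A minimal usage sketch (toy values, not taken from the repo):

import torch
import torch.nn.functional as F

from colossalai.inference.utils import can_use_flash_attn2

sequence_lengths = torch.tensor([3, 5, 2])
dtype = torch.float16

cu_seqlens = None
if can_use_flash_attn2(dtype):
    # Only built when flash-attn is importable and the dtype is supported;
    # flash_attn_varlen_func expects int32 cumulative sequence lengths.
    cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))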

colossalai/shardformer/layer/attn.py (2 lines changed)

@@ -50,7 +50,7 @@ def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]:
     seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
     max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
     return max_seqlen_in_batch, cu_seqlens, indices
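
As a quick sanity check of what `get_pad_info` produces, a hand-worked toy example (not part of the diff):

import torch
import torch.nn.functional as F

# Two sequences of real lengths 3 and 4, right-padded to length 4.
padding_mask = torch.tensor([[1, 1, 1, 0],
                             [1, 1, 1, 1]])

seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)              # tensor([3, 4], dtype=torch.int32)
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()   # tensor([0, 1, 2, 4, 5, 6, 7])
max_seqlen_in_batch = seqlens_in_batch.max().item()                         # 4
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# tensor([0, 3, 7], dtype=torch.int32) -- flash-attn's varlen kernels expect int32 here.
# The fix above replaces the accidental `torch.torch.int32` spelling (which appears to
# resolve to the same dtype, since `torch.torch` aliases the module) with `torch.int32`.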

tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py (2 lines changed)

@@ -26,7 +26,7 @@ def prepare_data(
     num_tokens = torch.sum(context_lengths).item()
     max_seq_len_in_batch = context_lengths.max()
-    cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
+    cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.int32), (1, 0))
     kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
     key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)

tests/test_infer/test_models/test_custom_model.py (161 lines added, new file)

@@ -0,0 +1,161 @@
import os
import random

import numpy as np
import pytest
import torch
import torch.distributed as dist
from torch.multiprocessing import Manager
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, LlamaForCausalLM, LlamaTokenizer

import colossalai
import colossalai.inference.modeling.policy as policy
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

# NOTE: To test a model with the inference engine, you need to provide the path to your
# local pretrained model weights in the MODEL_MAP dictionary
MODEL_MAP = {
    "baichuan": {
        "model": AutoModelForCausalLM,
        "tokenizer": AutoTokenizer,
        "policy": policy.NoPaddingBaichuanModelInferPolicy,
        "model_name_or_path": "baichuan-inc/Baichuan2-13B-Base",  # provide the path to local model weights
    },
    "llama": {
        "model": LlamaForCausalLM,
        "tokenizer": LlamaTokenizer,
        "policy": policy.NoPaddingLlamaModelInferPolicy,
        "model_name_or_path": "meta-llama/Llama-2-70b-hf",
    },
}

MODELS_TO_TEST = ["llama", "baichuan"]  # Specify the models to test


@parameterize("model", MODELS_TO_TEST)
@parameterize("prompt_template", [None, "model_specific"])
@parameterize("do_sample", [False])
@parameterize("use_cuda_kernel", [True])
@pytest.mark.largedist
@rerun_if_address_is_in_use()
def test_model(model, prompt_template, do_sample, use_cuda_kernel):
    model_path = MODEL_MAP[model]["model_name_or_path"]
    if not os.path.exists(model_path):
        pytest.skip(
            f"There is no local model address included for {model}, please replace this address with a valid one."
        )

    if prompt_template == "model_specific":
        prompt_template = model

    model_config = MODEL_MAP[model]

    kwargs1 = {
        "model": model,
        "use_engine": True,
        "prompt_template": prompt_template,
        "do_sample": do_sample,
        "policy": model_config["policy"](),
        "use_cuda_kernel": use_cuda_kernel,
    }

    kwargs2 = {
        "model": model,
        "use_engine": False,
        "prompt_template": prompt_template,
        "do_sample": do_sample,
        "policy": None,
        "use_cuda_kernel": use_cuda_kernel,
    }

    colossal_tp_1_output = run_engine(1, **kwargs1)
    colossal_tp_2_output = run_engine(2, **kwargs1)
    transformer_tp_1_output = run_engine(1, **kwargs2)

    for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
        assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
        assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"


def run_engine(world_size, **kwargs):
    manager = Manager()
    result_list = manager.list([-1] * world_size)  # Create a shared list
    spawn(run_dist, world_size, func_to_run=_run_engine, ret=result_list, **kwargs)
    return result_list[0]


def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

    if ret:
        ret[rank] = func_to_run(**kwargs)
    else:
        func_to_run(**kwargs)


def _run_engine(model, use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None, policy=None):
    setup_seed(20)
    model_config = MODEL_MAP[model]
    model_name_or_path = model_config["model_name_or_path"]
    tokenizer = model_config["tokenizer"].from_pretrained(model_name_or_path, use_fast=False, trust_remote_code=True)
    model = model_config["model"].from_pretrained(model_name_or_path, trust_remote_code=True).half().cuda()
    model = model.eval()

    inputs = [
        "Introduce some landmarks in Paris:",
    ]

    output_len = 38

    if do_sample:
        top_p = 0.5
        top_k = 50
    else:
        top_p = None
        top_k = None

    if use_engine:
        inference_config = InferenceConfig(
            max_output_len=output_len,
            prompt_template=prompt_template,
            use_cuda_kernel=use_cuda_kernel,
            tp_size=dist.get_world_size(),
        )
        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
        assert inference_engine.generation_config.max_new_tokens == output_len
        inference_engine.add_request(prompts=inputs)
        assert inference_engine.request_handler._has_waiting()
        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=output_len)
        outputs = inference_engine.generate(generation_config=generation_config)
    else:
        if prompt_template:
            # apply prompt template
            inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
        inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
        inputs = inputs.cuda()
        generation_config = GenerationConfig(
            do_sample=do_sample,
            top_p=top_p,
            top_k=top_k,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=output_len,
        )
        outputs = model.generate(inputs, generation_config=generation_config)
        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


if __name__ == "__main__":
    test_model()
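
For anyone adapting this test: assuming the repository's pytest markers are registered as usual, it should be runnable with something like `pytest -m largedist tests/test_infer/test_models/test_custom_model.py` after pointing the `model_name_or_path` entries in `MODEL_MAP` at local weights; otherwise the `os.path.exists` check makes the test skip.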