diff --git a/colossalai/inference/modeling/backends/attention_backend.py b/colossalai/inference/modeling/backends/attention_backend.py
index e0a4ec33d..ab586f510 100644
--- a/colossalai/inference/modeling/backends/attention_backend.py
+++ b/colossalai/inference/modeling/backends/attention_backend.py
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 
 import torch
-from flash_attn import flash_attn_varlen_func
 
 from colossalai.inference.config import ModelShardInferenceConfig
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
@@ -44,7 +43,7 @@ class CudaAttentionBackend(AttentionBackend):
     it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding.
     """
 
-    def __init__(self, use_flash_attn: bool):
+    def __init__(self, use_flash_attn: bool = False):
         super().__init__()
         self.inference_ops = InferenceOpsLoader().load()
         self.use_flash_attn = use_flash_attn
@@ -52,6 +51,9 @@ class CudaAttentionBackend(AttentionBackend):
     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
         if self.use_flash_attn:
             token_nums = kwargs.get("token_nums", -1)
+
+            from flash_attn import flash_attn_varlen_func
+
             attn_output = flash_attn_varlen_func(
                 attn_metadata.query_states,
                 attn_metadata.key_states,
diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py
index ba3e7b4e8..3bab671c4 100644
--- a/colossalai/inference/modeling/models/nopadding_baichuan.py
+++ b/colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -200,8 +200,6 @@ class NopadBaichuanAttention(ParallelModule):
 
             self.pre_attention_backend.decode(
                 attn_metadata,
-                cos=cos_sin[0],
-                sin=cos_sin[1],
                 q_len=q_len,
             )
             attn_output = self.attention_backend.decode(
diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py
index e274e7b7c..445ec59ce 100644
--- a/colossalai/inference/modeling/models/nopadding_llama.py
+++ b/colossalai/inference/modeling/models/nopadding_llama.py
@@ -114,7 +114,7 @@
 
     elif use_cuda_kernel:
         if can_use_flash_attn2(inputmetadata.dtype):
-            cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
+            cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
 
         hidden_dim = self._cos_cached.size(-1)
         total_length = hidden_states.size(0)
@@ -265,7 +265,7 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
         mlp_dproj: ParallelModule = None,
         process_group: ProcessGroup = None,
     ):
-        """A Unified Layer for
+        """Replacement of LlamaMLP layer.
 
         Args:
             config (LlamaConfig): Holding the Llama model config.
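The hunks above defer the `flash_attn` import into the code path that actually uses it and fix the `torch.torch.int32` typo in the `cu_seqlens` computation. A minimal standalone sketch of both patterns, not part of the patch: the helpers `build_cu_seqlens` and `varlen_prefill` are made-up names for illustration, while `flash_attn_varlen_func` is flash-attn's real varlen entry point.

import torch
import torch.nn.functional as F


def build_cu_seqlens(sequence_lengths: torch.Tensor) -> torch.Tensor:
    # Cumulative sequence lengths prefixed with 0, e.g. [3, 5, 2] -> [0, 3, 8, 10].
    # Varlen kernels expect int32 here; the corrected dtype is torch.int32 (not "torch.torch.int32").
    return F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))


def varlen_prefill(q, k, v, sequence_lengths, max_seqlen):
    # Deferred import: flash_attn is only required when this branch runs, so modules that
    # merely import the attention backend still load on machines without flash-attn installed.
    from flash_attn import flash_attn_varlen_func

    cu_seqlens = build_cu_seqlens(sequence_lengths)
    return flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        causal=True,
    )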
diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py
index 1374103a9..8c155e6ca 100644
--- a/colossalai/inference/utils.py
+++ b/colossalai/inference/utils.py
@@ -152,6 +152,8 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool:
         return False
 
     try:
+        from flash_attn import flash_attn_varlen_func  # noqa
+
         return True
     except ImportError:
         logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index abc865a34..141baf3d3 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -50,7 +50,7 @@ def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.T
     seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
     max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
     return max_seqlen_in_batch, cu_seqlens, indices
 
 
diff --git a/tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py b/tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py
index d90f64690..c3f2d0144 100644
--- a/tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py
+++ b/tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py
@@ -26,7 +26,7 @@ def prepare_data(
     num_tokens = torch.sum(context_lengths).item()
 
     max_seq_len_in_batch = context_lengths.max()
-    cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
+    cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.int32), (1, 0))
 
     kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
     key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
diff --git a/tests/test_infer/test_models/test_custom_model.py b/tests/test_infer/test_models/test_custom_model.py
new file mode 100644
index 000000000..f78731acf
--- /dev/null
+++ b/tests/test_infer/test_models/test_custom_model.py
@@ -0,0 +1,161 @@
+import os
+import random
+
+import numpy as np
+import pytest
+import torch
+import torch.distributed as dist
+from torch.multiprocessing import Manager
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, LlamaForCausalLM, LlamaTokenizer
+
+import colossalai
+import colossalai.inference.modeling.policy as policy
+from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+
+# NOTE: To test a model with the inference engine, you need to provide the path to your
+# local pretrained model weights in the MODEL_MAP dictionary
+MODEL_MAP = {
+    "baichuan": {
+        "model": AutoModelForCausalLM,
+        "tokenizer": AutoTokenizer,
+        "policy": policy.NoPaddingBaichuanModelInferPolicy,
+        "model_name_or_path": "baichuan-inc/Baichuan2-13B-Base",  # provide the path to local model weights
+    },
+    "llama": {
+        "model": LlamaForCausalLM,
+        "tokenizer": LlamaTokenizer,
+        "policy": policy.NoPaddingLlamaModelInferPolicy,
+        "model_name_or_path": "meta-llama/Llama-2-70b-hf",
+    },
+}
+
+MODELS_TO_TEST = ["llama", "baichuan"]  # Specify the models to test
+
+
+@parameterize("model", MODELS_TO_TEST)
+@parameterize("prompt_template", [None, "model_specific"])
+@parameterize("do_sample", [False])
+@parameterize("use_cuda_kernel", [True])
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+def test_model(model, prompt_template, do_sample, use_cuda_kernel):
+    model_path = MODEL_MAP[model]["model_name_or_path"]
+    if not os.path.exists(model_path):
+        pytest.skip(
+            f"There is no local model address included for {model}, please replace this address with a valid one."
+        )
+
+    if prompt_template == "model_specific":
+        prompt_template = model
+
+    model_config = MODEL_MAP[model]
+
+    kwargs1 = {
+        "model": model,
+        "use_engine": True,
+        "prompt_template": prompt_template,
+        "do_sample": do_sample,
+        "policy": model_config["policy"](),
+        "use_cuda_kernel": use_cuda_kernel,
+    }
+
+    kwargs2 = {
+        "model": model,
+        "use_engine": False,
+        "prompt_template": prompt_template,
+        "do_sample": do_sample,
+        "policy": None,
+        "use_cuda_kernel": use_cuda_kernel,
+    }
+
+    colossal_tp_1_output = run_engine(1, **kwargs1)
+    colossal_tp_2_output = run_engine(2, **kwargs1)
+    transformer_tp_1_output = run_engine(1, **kwargs2)
+
+    for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
+        assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
+        assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
+
+
+def run_engine(world_size, **kwargs):
+    manager = Manager()
+    result_list = manager.list([-1] * world_size)  # Create a shared list
+    spawn(run_dist, world_size, func_to_run=_run_engine, ret=result_list, **kwargs)
+    return result_list[0]
+
+
+def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
+
+    if ret:
+        ret[rank] = func_to_run(**kwargs)
+    else:
+        func_to_run(**kwargs)
+
+
+def _run_engine(model, use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None, policy=None):
+    setup_seed(20)
+    model_config = MODEL_MAP[model]
+    model_name_or_path = model_config["model_name_or_path"]
+    tokenizer = model_config["tokenizer"].from_pretrained(model_name_or_path, use_fast=False, trust_remote_code=True)
+    model = model_config["model"].from_pretrained(model_name_or_path, trust_remote_code=True).half().cuda()
+    model = model.eval()
+
+    inputs = [
+        "Introduce some landmarks in Paris:",
+    ]
+
+    output_len = 38
+
+    if do_sample:
+        top_p = 0.5
+        top_k = 50
+    else:
+        top_p = None
+        top_k = None
+
+    if use_engine:
+        inference_config = InferenceConfig(
+            max_output_len=output_len,
+            prompt_template=prompt_template,
+            use_cuda_kernel=use_cuda_kernel,
+            tp_size=dist.get_world_size(),
+        )
+        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
+        assert inference_engine.generation_config.max_new_tokens == output_len
+        inference_engine.add_request(prompts=inputs)
+        assert inference_engine.request_handler._has_waiting()
+        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=output_len)
+        outputs = inference_engine.generate(generation_config=generation_config)
+    else:
+        if prompt_template:
+            # apply prompt template
+            inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
+        inputs = inputs.cuda()
+        generation_config = GenerationConfig(
+            do_sample=do_sample,
+            top_p=top_p,
+            top_k=top_k,
+            pad_token_id=tokenizer.pad_token_id,
+            max_new_tokens=output_len,
+        )
+        outputs = model.generate(inputs, generation_config=generation_config)
+        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+    return outputs


+def setup_seed(seed):
+    torch.manual_seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+
+if __name__ == "__main__":
+    test_model()
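The new test collects each rank's generations through a `Manager` list handed to `colossalai.testing.spawn`, then compares ColossalAI TP=1, TP=2, and plain transformers outputs. A minimal sketch of that result-collection pattern using plain `torch.multiprocessing` instead of ColossalAI's helpers; the worker body and names here are placeholders, not ColossalAI's API.

import torch.multiprocessing as mp
from torch.multiprocessing import Manager


def _worker(rank: int, world_size: int, ret):
    # Each rank writes its result into the shared proxy list; the parent reads rank 0's entry,
    # mirroring how run_engine() above returns result_list[0].
    ret[rank] = f"result-from-rank-{rank}"


def run_workers(world_size: int):
    manager = Manager()
    ret = manager.list([-1] * world_size)  # shared across spawned processes
    mp.spawn(_worker, args=(world_size, ret), nprocs=world_size, join=True)
    return ret[0]


if __name__ == "__main__":
    print(run_workers(2))  # prints "result-from-rank-0"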