diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/inference/tensor_parallel/__init__.py b/colossalai/inference/tensor_parallel/__init__.py new file mode 100644 index 000000000..e467b4c73 --- /dev/null +++ b/colossalai/inference/tensor_parallel/__init__.py @@ -0,0 +1,4 @@ +from .engine import TPInferEngine +from .kvcache_manager import MemoryManager + +__all__ = ['MemoryManager', 'TPInferEngine'] diff --git a/colossalai/shardformer/inference/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py similarity index 89% rename from colossalai/shardformer/inference/batch_infer_state.py rename to colossalai/inference/tensor_parallel/batch_infer_state.py index fef23a584..2bff93172 100644 --- a/colossalai/shardformer/inference/batch_infer_state.py +++ b/colossalai/inference/tensor_parallel/batch_infer_state.py @@ -21,6 +21,7 @@ class BatchInferState: block_loc: torch.Tensor = None start_loc: torch.Tensor = None seq_len: torch.Tensor = None + past_key_values_len: int = None is_context_stage: bool = False context_mem_index: torch.Tensor = None @@ -34,7 +35,9 @@ class BatchInferState: @property def total_token_num(self): - return self.batch_size * self.max_len_in_batch + # return self.batch_size * self.max_len_in_batch + assert self.seq_len is not None and self.seq_len.size(0) > 0 + return int(torch.sum(self.seq_len)) def set_cache_manager(self, manager: MemoryManager): self.cache_manager = manager diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py new file mode 100644 index 000000000..f643d892a --- /dev/null +++ b/colossalai/inference/tensor_parallel/engine.py @@ -0,0 +1,254 @@ +from typing import Any, Callable, Dict, List, Optional, Set, Union + +import torch +import torch.nn as nn +from transformers import BloomForCausalLM, LlamaForCausalLM +from transformers.generation import GenerationConfig +from transformers.generation.stopping_criteria import StoppingCriteriaList +from transformers.tokenization_utils_base import BatchEncoding + +from colossalai.cluster import ProcessGroupMesh +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.auto_policy import get_autopolicy + +from .batch_infer_state import BatchInferState +from .kvcache_manager import MemoryManager + +DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 + +_supported_models = ['LlamaForCausalLM', 'BloomForCausalLM'] + + +class TPInferEngine: + + def __init__(self, + model: nn.Module, + max_batch_size: int, + max_input_len: int, + max_output_len: int, + dtype: torch.dtype = torch.float16, + device: torch.device = torch.cuda.current_device()) -> None: + self.model = model + self.sharded_model = None + + self.max_batch_size = max_batch_size + self.max_input_len = max_input_len + self.max_output_len = max_output_len + self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len) + + # Constraints relatable with specs of devices + assert self.max_batch_size <= 64, "Max batch size exceeds the constraint" + assert self.max_input_len + self.max_output_len <= 2048, "Max length exceeds the constraint" + + self.device = device + self.dtype = dtype + + self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads + self.head_num = self.model.config.num_attention_heads + self.layer_num = self.model.config.num_hidden_layers + + self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config + self.cache_manager = None + + def _init_manager(self) -> None: + assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig" + assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}" + self.head_num //= self.tp_size # update sharded number of heads + self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim, + self.layer_num) + + def prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig: + """ Prepare the engine with a given ShardConfig, or create a default one with tp size 1 """ + self.tp_size = 1 + if shard_config is None: + shard_config = ShardConfig( + tensor_parallel_process_group=None, + pipeline_stage_manager=None, + enable_tensor_parallelism=False, + enable_fused_normalization=False, + enable_all_optimization=False, + enable_flash_attention=False, + enable_jit_fused=False, + inference_only=True, + ) + else: + shard_config.inference_only = True + shard_config.pipeline_stage_manager = None + if shard_config.enable_tensor_parallelism: + self.tp_size = shard_config.tensor_parallel_size + self._init_manager() + + return shard_config + + def shard_model_by(self, shardformer: ShardFormer) -> None: + """ Shard the model and store the sharded model by given ShardFormer """ + assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \ + "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" + model_name = self.model.__class__.__name__ + assert model_name in self._supported_models(), f"Unsupported model cls {model_name} for TP inference." + policy = get_autopolicy(self.model, inference_only=True) + self.sharded_model, _ = shardformer.optimize(self.model, policy) + self.sharded_model = self.sharded_model.to(self.device) + + @staticmethod + def _supported_models() -> List[str]: + return _supported_models + + def generate(self, input_tokens, generate_kwargs) -> torch.Tensor: + if isinstance(input_tokens, torch.Tensor): + input_tokens = dict(input_ids=input_tokens, attention_mask=torch.ones_like(input_tokens, dtype=torch.bool)) + if self.sharded_model is not None: + return self.generate_by_set_infer_state(input_tokens, generate_kwargs) + + return self.model.generate(**input_tokens, **generate_kwargs) + + @torch.no_grad() + def generate_by_set_infer_state(self, input_tokens, generate_kwargs) -> torch.Tensor: + """ + Generate output tokens by setting BatchInferState as an attribute to the model and calling model.generate + + Args: + inputs: should be one of the following types + 1. BatchEncoding or dict (e.g. tokenizer batch_encode) + 2. list of input token ids (e.g. appended result of tokenizer encode) + 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') + """ + + # for testing, always use sharded model + assert self.sharded_model is not None, "sharded model does not exist" + + batch_infer_state = self.prepare_batch_state(input_tokens) + assert batch_infer_state.max_len_in_batch <= self.max_input_len, "max length in batch exceeds limit" + + # set BatchInferState for the current batch as attr to model + # NOTE this is not an expectable way to pass BatchInferState during inference + # we might want to rewrite generate function (e.g. generate_by_pass_infer_state) + # and pass BatchInferState via model forward + model = self.sharded_model + if isinstance(model, LlamaForCausalLM): + model = self.sharded_model.model + elif isinstance(model, BloomForCausalLM): + model = self.sharded_model.transformer + setattr(model, 'infer_state', batch_infer_state) + + generate_kwargs.update(max_new_tokens=self.max_output_len) + + if isinstance(input_tokens, torch.Tensor): + input_tokens = dict(input_ids=input_tokens) + for t in input_tokens: + if torch.is_tensor(input_tokens[t]): + input_tokens[t] = input_tokens[t].to(self.device) + + outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False) + + print(f"outputs.shape {outputs.shape}") + return outputs + + def prepare_batch_state(self, inputs) -> BatchInferState: + """ + Create and prepare BatchInferState used for inference during model forwrad, + by processing each sequence of the given inputs + + Args: + inputs: should be one of the following types + 1. BatchEncoding or dict (e.g. tokenizer batch_encode) + 2. list of input token ids (e.g. appended result of tokenizer encode) + 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') + NOTE For torch.Tensor inputs representing a batch of inputs, we are unable to retrieve + the actual length (e.g. number of tokens) of each input without attention mask + Hence, for torch.Tensor with shape [bs, l] where bs > 1, we will assume + all the inputs in the batch has the maximum length l + Returns: + BatchInferState: the states for the current batch during inference + """ + if not isinstance(inputs, (BatchEncoding, dict, list, torch.Tensor)): + raise TypeError(f"inputs type {type(inputs)} is not supported in prepare_batch_state") + + if isinstance(inputs, (BatchEncoding, dict)): + attn_masks = inputs['attention_mask'] + batch_size = attn_masks.shape[0] + max_len_in_batch = attn_masks.shape[1] + elif isinstance(inputs, list): + batch_size = len(inputs) + else: + batch_size = inputs.shape[0] + + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device=self.device) + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device=self.device) + start_index = 0 + + max_len_in_batch = -1 + if isinstance(inputs, (BatchEncoding, dict)): + for i, attn_mask in enumerate(attn_masks): + curr_seq_len = int(torch.sum(attn_mask)) + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + else: + for i, input_ids in enumerate(inputs): + curr_seq_len = len(input_ids) + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + + print(" 666 ", max_len_in_batch) + + block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), + dtype=torch.long, + device=self.device) + batch_infer_state = BatchInferState(batch_size, max_len_in_batch) + batch_infer_state.seq_len = seq_lengths.to(self.device) # might want to assign specific device + batch_infer_state.start_loc = seq_start_indexes.to(self.device) + batch_infer_state.block_loc = block_loc + batch_infer_state.decode_layer_id = 0 + batch_infer_state.past_key_values_len = 0 + batch_infer_state.is_context_stage = True + batch_infer_state.set_cache_manager(self.cache_manager) + return batch_infer_state + + # TODO might want to implement the func that generates output tokens by passing BatchInferState + # as an arg into model.forward + # requires rewriting model generate and replacing model forward + @torch.no_grad() + def generate_by_pass_infer_state(self, + input_tokens, + max_out_length: int, + generation_config: Optional[GenerationConfig] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + **model_kwargs) -> torch.Tensor: + # if batch_size >= 4: + # assert self.sharded_model is not None, "sharded model does not exist" + # batch_infer_state = self.prepare_batch_state(input_tokens) + # batch_size = batch_infer_state.batch_size + # assert batch_infer_state.max_len_in_batch <= self.max_input_len + # # record sequences finish status, add early stopping, etc, + # for _ in range(min(max_out_length, self.max_output_len)): + # # ... + # self.sharded_model.forward(..., **model_kwargs) + # else: + # Use original model to generate + raise NotImplementedError("generate by passing BatchInferState is not implemented.") + + # NOTE might want to use in rewritten generate method: use after model.forward + # BatchInferState is created and kept during generation + # after each iter of model forward, we should update BatchInferState + def update_batch_state(self, infer_state: Optional[BatchInferState]) -> None: + batch_size = infer_state.batch_size + device = infer_state.start_loc.device + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device) + infer_state.seq_len += 1 + + # TODO might want to create a sequence pool + # add a single request/sequence/input text at a time and record its length + # In other words, store the actual length of input tokens representing a single input text + # E.g. "Introduce landmarks in Beijing" + # => add request + # => record token length and other necessary information to be used + # => engine hold all these necessary information until `generate` (or other name) is called, + # => put information already recorded in batchinferstate and pass it to model forward + # => clear records in engine + def add_request(): + raise NotImplementedError() diff --git a/colossalai/shardformer/inference/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py similarity index 100% rename from colossalai/shardformer/inference/kvcache_manager.py rename to colossalai/inference/tensor_parallel/kvcache_manager.py diff --git a/colossalai/shardformer/inference/__init__.py b/colossalai/shardformer/inference/__init__.py deleted file mode 100644 index 1bce92653..000000000 --- a/colossalai/shardformer/inference/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .batch_infer_state import BatchInferState -from .kvcache_manager import MemoryManager - -__all__ = ['BatchInferState', 'MemoryManager'] diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 43ea1c5ab..0ffa7fbee 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -134,7 +134,9 @@ _POLICY_LIST = { _INFER_POLICY_LIST = { # LlaMa "transformers.models.llama.modeling_llama.LlamaModel": - PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy") + PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"), + "transformers.models.llama.modeling_llama.LlamaForCausalLM": + PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"), } diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py new file mode 100644 index 000000000..7fcb36554 --- /dev/null +++ b/tests/test_infer/test_infer_engine.py @@ -0,0 +1,70 @@ +import pytest +import torch +from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM + +import colossalai +from colossalai.inference.tensor_parallel import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn + +TP_SIZE = 2 +BATCH_SIZE = 4 +MAX_INPUT_LEN = 16 +MAX_OUTPUT_LEN = 8 + + +def test_orig_generate(): + input_ids = torch.randint(low=10, high=1000, size=(BATCH_SIZE, MAX_INPUT_LEN)) + + model_config = LlamaConfig() + model = LlamaForCausalLM(model_config) + shard_config = ShardConfig(enable_tensor_parallelism=False) + + # init TPInferEngine and + infer_engine = TPInferEngine(model, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + infer_engine.prepare_with_shard_config(shard_config) + + # original model generate + generate_kwargs = dict(do_sample=False) + infer_engine.generate(input_ids, generate_kwargs) + + +def run(): + model_config = LlamaConfig() + model = LlamaForCausalLM(model_config) + shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) + shardformer = ShardFormer(shard_config=shard_config) + + infer_engine = TPInferEngine(model, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + infer_engine.prepare_with_shard_config(shard_config=shard_config) + infer_engine.shard_model_by(shardformer) + + assert infer_engine.cache_manager is not None + assert infer_engine.tp_size == TP_SIZE + assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE + + # TODO After adding forward replacement for CausalLM, + # uncomment these lines to test sharded model generate + # generate_kwargs = dict(do_sample=False) + # infer_engine.generate(input_ids, generate_kwargs) + + torch.cuda.empty_cache() + + +def check_engine(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_engine_infer(): + spawn(check_engine, TP_SIZE) + + +if __name__ == '__main__': + test_orig_generate() + test_engine_infer() diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py index ef48444f7..fb04d7800 100644 --- a/tests/test_infer/test_kvcache_manager.py +++ b/tests/test_infer/test_kvcache_manager.py @@ -3,8 +3,8 @@ import os import pytest import torch +from colossalai.inference.tensor_parallel import MemoryManager from colossalai.logging import disable_existing_loggers -from colossalai.shardformer.inference import MemoryManager from colossalai.testing import rerun_if_address_is_in_use, spawn BATCH_SIZE = 4