diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py
index 8d251e5bf..8cdc6db55 100644
--- a/colossalai/inference/__init__.py
+++ b/colossalai/inference/__init__.py
@@ -1,4 +1,4 @@
 from .hybridengine import CaiInferEngine
-from .hybridengine.polices import BloomModelInferPolicy, LlamaModelInferPolicy
+from .hybridengine.polices import BloomModelInferPolicy, ChatGLM2InferPolicy, LlamaModelInferPolicy
 
-__all__ = ["CaiInferEngine", "LlamaModelInferPolicy", "BloomModelInferPolicy"]
+__all__ = ["CaiInferEngine", "LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy"]
diff --git a/colossalai/inference/hybridengine/engine.py b/colossalai/inference/hybridengine/engine.py
index 3a80723c3..5e944014b 100644
--- a/colossalai/inference/hybridengine/engine.py
+++ b/colossalai/inference/hybridengine/engine.py
@@ -14,8 +14,7 @@
 from ..tensor_parallel.kvcache_manager import MemoryManager
 
 PP_AXIS, TP_AXIS = 0, 1
 
-_supported_models = ["LlamaForCausalLM", "BloomForCausalLM", "LlamaGPTQForCausalLM", "SmoothLlamaForCausalLM"]
-
+_supported_models = ["LlamaForCausalLM", "BloomForCausalLM", "LlamaGPTQForCausalLM", "SmoothLlamaForCausalLM", "ChatGLMForConditionalGeneration"]
 class CaiInferEngine:
     """
@@ -184,6 +183,16 @@ class CaiInferEngine:
             head_num = model.config.n_head // self.tp_size
             num_hidden_layers = model.config.n_layer
             layer_num = num_hidden_layers // self.pp_size
+        elif model.config.model_type == "chatglm":
+            head_dim = model.config.hidden_size // model.config.num_attention_heads
+            if model.config.multi_query_attention:
+                head_num = model.config.multi_query_group_num // self.tp_size
+            else:
+                head_num = model.config.num_attention_heads // self.tp_size
+            num_hidden_layers = model.config.num_layers
+            layer_num = num_hidden_layers // self.pp_size
+        else:
+            raise NotImplementedError("Only llama, bloom and chatglm models are supported.")
 
         if self.quant == "smoothquant":
             cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num)
diff --git a/colossalai/inference/hybridengine/modeling/chatglm2.py b/colossalai/inference/hybridengine/modeling/chatglm2.py
new file mode 100644
index 000000000..0110b9d9a
--- /dev/null
+++ b/colossalai/inference/hybridengine/modeling/chatglm2.py
@@ -0,0 +1,492 @@
+from typing import List, Optional, Tuple
+
+import torch
+from transformers.utils import logging
+
+from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig
+from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
+    ChatGLMForConditionalGeneration,
+    ChatGLMModel,
+    GLMBlock,
+    GLMTransformer,
+    SelfAttention,
+    split_tensor_along_last_dim,
+)
+
+from ._utils import copy_kv_to_mem_cache
+
+try:
+    from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd
+    from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
+        context_attention_fwd as lightllm_llama2_context_attention_fwd,
+    )
+
+    HAS_LIGHTLLM_KERNEL = True
+except:
+    print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
+    HAS_LIGHTLLM_KERNEL = False
+
+
+def get_masks(self, input_ids, past_length, padding_mask=None):
+    batch_size, seq_length = input_ids.shape
+    full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
+    full_attention_mask.tril_()
+    if past_length:
+        full_attention_mask = torch.cat(
+            (
+                torch.ones(batch_size, seq_length, past_length, device=input_ids.device),
+                full_attention_mask,
+            ),
+            dim=-1,
+        )
+
+    if padding_mask is not None:
+        full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
+    if not past_length and padding_mask is not None:
+        full_attention_mask -= padding_mask.unsqueeze(-1) - 1
+    full_attention_mask = (full_attention_mask < 0.5).bool()
+    full_attention_mask.unsqueeze_(1)
+    return full_attention_mask
+
+
+def get_position_ids(batch_size, seq_length, device):
+    position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
+    return position_ids
+
+
+class ChatGLM2InferenceForwards:
+    """
+    This class holds forwards for ChatGLM2 inference.
+    We intend to replace the forward methods of ChatGLMModel, GLMTransformer, GLMBlock, and SelfAttention.
+    """
+
+    @staticmethod
+    def chatglm_for_conditional_generation_forward(
+        self: ChatGLMForConditionalGeneration,
+        input_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = True,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        return_last_logit: Optional[bool] = False,
+        infer_state: Optional[BatchInferState] = None,
+        stage_manager: Optional[PipelineStageManager] = None,
+        hidden_states: Optional[torch.FloatTensor] = None,
+        stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
+    ):
+        logger = logging.get_logger(__name__)
+
+        if output_attentions:
+            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+            output_attentions = False
+        if output_hidden_states:
+            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+            output_hidden_states = False
+
+        # If this is the first stage and hidden_states is not None, go through lm_head first
+        if stage_manager.is_first_stage() and hidden_states is not None:
+            if return_last_logit:
+                hidden_states = hidden_states[-1:]
+            lm_logits = self.transformer.output_layer(hidden_states)
+            lm_logits = lm_logits.transpose(0, 1).contiguous()
+            return {"logits": lm_logits}
+
+        outputs = self.transformer(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            infer_state=infer_state,
+            stage_manager=stage_manager,
+            hidden_states=hidden_states,
+            stage_index=stage_index,
+            shard_config=shard_config,
+        )
+
+        return outputs
+
+    @staticmethod
+    def chatglm_model_forward(
+        self: ChatGLMModel,
+        input_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        full_attention_mask: Optional[torch.BoolTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        infer_state: BatchInferState = None,
+        stage_manager: Optional[PipelineStageManager] = None,
+        hidden_states: Optional[torch.FloatTensor] = None,
+        stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
+    ):
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        if stage_manager.is_first_stage():
+            if input_ids is not None and inputs_embeds is not None:
+                raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+            elif input_ids is not None:
+                batch_size, seq_length = input_ids.shape
+            elif inputs_embeds is not None:
+                batch_size, seq_length, _ = inputs_embeds.shape
+            else:
+                raise ValueError("You have to specify either input_ids or inputs_embeds")
+            if inputs_embeds is None:
+                inputs_embeds = self.embedding(input_ids)
+            if position_ids is None:
+                position_ids = get_position_ids(batch_size, seq_length, input_ids.device)
+            hidden_states = inputs_embeds
+        else:
+            assert hidden_states is not None, "hidden_states should not be None in non-first stage"
+            seq_length, batch_size, _ = hidden_states.shape
+            if position_ids is None:
+                position_ids = get_position_ids(batch_size, seq_length, hidden_states.device)
+
+        if infer_state.is_context_stage:
+            past_key_values_length = 0
+        else:
+            past_key_values_length = infer_state.max_len_in_batch - 1
+
+        seq_length_with_past = seq_length + past_key_values_length
+
+        # prefill stage at first
+        if seq_length != 1:
+            infer_state.is_context_stage = True
+            infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
+            infer_state.init_block_loc(
+                infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
+            )
+        else:
+            infer_state.is_context_stage = False
+            alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
+            if alloc_mem is not None:
+                infer_state.decode_is_contiguous = True
+                infer_state.decode_mem_index = alloc_mem[0]
+                infer_state.decode_mem_start = alloc_mem[1]
+                infer_state.decode_mem_end = alloc_mem[2]
+                infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+            else:
+                print(" *** Encountered non-contiguous allocation")
+                print(
+                    f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
+                )
+                infer_state.decode_is_contiguous = False
+                alloc_mem = infer_state.cache_manager.alloc(batch_size)
+                infer_state.decode_mem_index = alloc_mem
+                infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+
+        # related to rotary embedding
+        if infer_state.is_context_stage:
+            infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
+                position_ids.view(-1).shape[0], -1
+            )
+            infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
+                position_ids.view(-1).shape[0], -1
+            )
+        else:
+            seq_len = infer_state.seq_len
+            infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+            infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+            infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item()
+
+        if self.pre_seq_len is not None:
+            if past_key_values is None:
+                past_key_values = self.get_prompt(
+                    batch_size=batch_size,
+                    device=input_ids.device,
+                    dtype=inputs_embeds.dtype,
+                )
+            if attention_mask is not None:
+                attention_mask = torch.cat(
+                    [
+                        attention_mask.new_ones((batch_size, self.pre_seq_len)),
+                        attention_mask,
+                    ],
+                    dim=-1,
+                )
+        if full_attention_mask is None:
+            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+                full_attention_mask = get_masks(
+                    self, input_ids, infer_state.cache_manager.past_key_values_length, padding_mask=attention_mask
+                )
+
+        # Run encoder.
+        hidden_states = self.encoder(
+            hidden_states,
+            full_attention_mask,
+            kv_caches=past_key_values,
+            use_cache=use_cache,
+            output_hidden_states=output_hidden_states,
+            infer_state=infer_state,
+            stage_manager=stage_manager,
+            stage_index=stage_index,
+            shard_config=shard_config,
+        )
+
+        # update indices
+        infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+        infer_state.seq_len += 1
+        infer_state.max_len_in_batch += 1
+
+        return {"hidden_states": hidden_states}
+
+    @staticmethod
+    def chatglm_encoder_forward(
+        self: GLMTransformer,
+        hidden_states,
+        attention_mask,
+        kv_caches=None,
+        use_cache: Optional[bool] = True,
+        output_hidden_states: Optional[bool] = False,
+        infer_state: Optional[BatchInferState] = None,
+        stage_manager: Optional[PipelineStageManager] = None,
+        stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
+    ):
+        hidden_states = hidden_states.transpose(0, 1).contiguous()
+
+        infer_state.decode_layer_id = 0
+        start_idx, end_idx = stage_index[0], stage_index[1]
+        if kv_caches is None:
+            kv_caches = tuple([None] * (end_idx - start_idx + 1))
+
+        for idx, kv_cache in zip(range(start_idx, end_idx), kv_caches):
+            layer = self.layers[idx]
+            layer_ret = layer(
+                hidden_states,
+                attention_mask,
+                kv_cache=kv_cache,
+                use_cache=use_cache,
+                infer_state=infer_state,
+            )
+            infer_state.decode_layer_id += 1
+
+            hidden_states, _ = layer_ret
+
+        hidden_states = hidden_states.transpose(0, 1).contiguous()
+
+        if self.post_layer_norm and (stage_manager.is_last_stage() or stage_manager.num_stages == 1):
+            # Final layer norm.
+            hidden_states = self.final_layernorm(hidden_states)
+
+        return hidden_states
+
+    @staticmethod
+    def chatglm_glmblock_forward(
+        self: GLMBlock,
+        hidden_states,
+        attention_mask,
+        kv_cache=None,
+        use_cache=True,
+        infer_state: Optional[BatchInferState] = None,
+    ):
+        # hidden_states: [s, b, h]
+
+        # Layer norm at the beginning of the transformer layer.
+        layernorm_output = self.input_layernorm(hidden_states)
+        # Self attention.
+        attention_output, kv_cache = self.self_attention(
+            layernorm_output,
+            attention_mask,
+            kv_cache=kv_cache,
+            use_cache=use_cache,
+            infer_state=infer_state,
+        )
+        # Residual connection.
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = hidden_states
+        layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
+        layernorm_input = residual + layernorm_input
+        # Layer norm post the self attention.
+        layernorm_output = self.post_attention_layernorm(layernorm_input)
+        # MLP.
+        mlp_output = self.mlp(layernorm_output)
+
+        # Second residual connection.
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = layernorm_input
+
+        output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
+        output = residual + output
+        return output, kv_cache
+
+    @staticmethod
+    def chatglm_flash_attn_kvcache_forward(
+        self: SelfAttention,
+        hidden_states,
+        attention_mask,
+        kv_cache=None,
+        use_cache=True,
+        infer_state: Optional[BatchInferState] = None,
+    ):
+        assert use_cache is True, "use_cache should be set to True when using this chatglm attention"
+        # hidden_states: originally [sq, b, h], here [b, sq, h]
+        batch_size = hidden_states.shape[0]
+        hidden_size = hidden_states.shape[-1]
+        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+        mixed_x_layer = self.query_key_value(hidden_states)
+        if self.multi_query_attention:
+            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
+                [
+                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
+                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+                ],
+                dim=-1,
+            )
+            query_layer = query_layer.view(
+                query_layer.size()[:-1]
+                + (
+                    self.num_attention_heads_per_partition,
+                    self.hidden_size_per_attention_head,
+                )
+            )
+            key_layer = key_layer.view(
+                key_layer.size()[:-1]
+                + (
+                    self.num_multi_query_groups_per_partition,
+                    self.hidden_size_per_attention_head,
+                )
+            )
+            value_layer = value_layer.view(
+                value_layer.size()[:-1]
+                + (
+                    self.num_multi_query_groups_per_partition,
+                    self.hidden_size_per_attention_head,
+                )
+            )
+
+        else:
+            new_tensor_shape = mixed_x_layer.size()[:-1] + (
+                self.num_attention_heads_per_partition,
+                3 * self.hidden_size_per_attention_head,
+            )
+            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
+            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
+        cos, sin = infer_state.position_cos, infer_state.position_sin
+
+        chatglm2_rotary_emb_fwd(
+            query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin
+        )
+        if self.multi_query_attention:
+            chatglm2_rotary_emb_fwd(
+                key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head),
+                cos,
+                sin,
+            )
+        else:
+            chatglm2_rotary_emb_fwd(
+                key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
+                cos,
+                sin,
+            )
+
+        # reshape q, k, v to [bsz * seq_len, num_heads, head_dim]
+        query_layer = query_layer.reshape(
+            -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
+        )
+        key_layer = key_layer.reshape(
+            -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
+        )
+        value_layer = value_layer.reshape(
+            -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
+        )
+
+        if infer_state.is_context_stage:
+            # first token generation:
+            # copy key and value calculated in current step to memory manager
+            copy_kv_to_mem_cache(
+                infer_state.decode_layer_id,
+                key_layer,
+                value_layer,
+                infer_state.context_mem_index,
+                infer_state.cache_manager,
+            )
+            attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size))
+
+            # prefill: context attention over the whole prompt via the lightllm llama2 kernel
+            lightllm_llama2_context_attention_fwd(
+                query_layer,
+                key_layer,
+                value_layer,
+                attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
+                infer_state.start_loc,
+                infer_state.seq_len,
+                infer_state.max_len_in_batch,
+            )
+
+        else:
+            if infer_state.decode_is_contiguous:
+                # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
+                cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
+                    infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+                ]
+                cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
+                    infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+                ]
+                cache_k.copy_(key_layer)
+                cache_v.copy_(value_layer)
+            else:
+                # if decode is not contiguous, use triton kernel to copy key and value cache
+                # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
+                copy_kv_to_mem_cache(
+                    infer_state.decode_layer_id,
+                    key_layer,
+                    value_layer,
+                    infer_state.decode_mem_index,
+                    infer_state.cache_manager,
+                )
+
+            # decode stage: second token and onwards
+            attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size))
+            cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
+                : infer_state.decode_mem_end, :, :
+            ]
+            cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
+                : infer_state.decode_mem_end, :, :
+            ]
+
+            # ==================================
+            # core attention computation is replaced by triton kernel
+            # ==================================
+            Llama2TokenAttentionForwards.token_attn(
+                query_layer,
+                cache_k,
+                cache_v,
+                attn_output,
+                infer_state.block_loc,
+                infer_state.start_loc,
+                infer_state.seq_len,
+                infer_state.max_len_in_batch,
+                infer_state.other_kv_index,
+            )
+
+        # =================
+        # Output: [b, sq, h]
+        # =================
+        output = self.dense(attn_output).reshape(batch_size, -1, hidden_size)
+
+        return output, kv_cache
diff --git a/colossalai/inference/hybridengine/polices/__init__.py b/colossalai/inference/hybridengine/polices/__init__.py
index e40ad8143..84dfb5aff 100644
--- a/colossalai/inference/hybridengine/polices/__init__.py
+++ b/colossalai/inference/hybridengine/polices/__init__.py
@@ -1,4 +1,5 @@
 from .bloom import BloomModelInferPolicy
+from .chatglm import ChatGLM2InferPolicy
 from .llama import LlamaModelInferPolicy
 
-__all__ = ["LlamaModelInferPolicy", "BloomModelInferPolicy"]
+__all__ = ["LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy"]
diff --git a/colossalai/inference/hybridengine/polices/chatglm.py b/colossalai/inference/hybridengine/polices/chatglm.py
new file mode 100644
index 000000000..3e1d94f47
--- /dev/null
+++ b/colossalai/inference/hybridengine/polices/chatglm.py
@@ -0,0 +1,89 @@
+from typing import List
+
+import torch.nn as nn
+
+from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
+    ChatGLMForConditionalGeneration,
+    ChatGLMModel,
+    GLMBlock,
+    GLMTransformer,
+    SelfAttention,
+)
+
+# import colossalai
+from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy
+
+from ..modeling._utils import init_to_get_rotary
+from ..modeling.chatglm2 import ChatGLM2InferenceForwards
+
+try:
+    HAS_TRITON_RMSNORM = True
+except:
+    print("you should install triton from https://github.com/openai/triton")
+    HAS_TRITON_RMSNORM = False
+
+
+class ChatGLM2InferPolicy(ChatGLMModelPolicy):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        policy = super().module_policy()
+        self.shard_config._infer()
+
+        model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward
+        method_replacement = {"forward": model_infer_forward}
+        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel)
+
+        encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward
+        method_replacement = {"forward": encoder_infer_forward}
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=GLMTransformer
+        )
+
+        encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward
+        method_replacement = {"forward": encoder_layer_infer_forward}
+        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock)
+
+        attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward
+        method_replacement = {"forward": attn_infer_forward}
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=SelfAttention
+        )
+        if self.shard_config.enable_tensor_parallelism:
+            policy[GLMBlock].attribute_replacement["self_attention.num_multi_query_groups_per_partition"] = (
+                self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size
+            )
+        # for rmsnorm and others, we need to check the shape
+
+        self.set_pipeline_forward(
+            model_cls=ChatGLMForConditionalGeneration,
+            new_forward=ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward,
+            policy=policy,
+        )
+
+        return policy
+
+    def get_held_layers(self) -> List[nn.Module]:
+        module = self.model.transformer
+        stage_manager = self.pipeline_stage_manager
+
+        held_layers = []
+        layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
+        if stage_manager.is_first_stage():
+            held_layers.append(module.embedding)
+            held_layers.append(module.output_layer)
+        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        held_layers.extend(module.encoder.layers[start_idx:end_idx])
+        if stage_manager.is_last_stage():
+            if module.encoder.post_layer_norm:
+                held_layers.append(module.encoder.final_layernorm)
+
+        # rotary_pos_emb is needed for all stages
+        held_layers.append(module.rotary_pos_emb)
+
+        return held_layers
+
+    def postprocess(self):
+        init_to_get_rotary(self.model.transformer)
+        return self.model
diff --git a/tests/test_infer/test_hybrid_chatglm2.py b/tests/test_infer/test_hybrid_chatglm2.py
new file mode 100644
index 000000000..d5c5f0dee
--- /dev/null
+++ b/tests/test_infer/test_hybrid_chatglm2.py
@@ -0,0 +1,110 @@
+import pytest
+import torch
+import torch.distributed as dist
+from packaging import version
+
+import colossalai
+from colossalai.inference import CaiInferEngine, ChatGLM2InferPolicy
+from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
+from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
+
+
+def data_gen():
+    input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+    return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+inputs = data_gen()
+for k, v in inputs.items():
+    if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
+        new_shape = [1] * v.dim()
+        new_shape[0] = 16
+        inputs[k] = v.to("cuda").repeat(*new_shape)
+
+
+def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+    chatglm_config = ChatGLMConfig(
+        num_layers=2,
+        vocab_size=20000,
+        use_cache=True,
+        multi_query_attention=True,
+        multi_query_group_num=2,
+        num_attention_heads=8,
+        hidden_size=1024,
+    )
+    model = ChatGLMForConditionalGeneration(chatglm_config)
+
+    engine = CaiInferEngine(
+        tp_size=tp_size,
+        pp_size=pp_size,
+        model=model,
+        model_policy=ChatGLM2InferPolicy(),
+        max_output_len=max_output_len,
+        micro_batch_size=micro_batch_size,
+    )
+    output = engine.inference(inputs)
+    if dist.get_rank() == 0:
+        assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"
+
+
+@parameterize("tp_size", [1])
+@parameterize("pp_size", [2])
+@parameterize("max_output_len", [4])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+    pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+    torch.cuda.empty_cache()
+
+
+@parameterize("tp_size", [2])
+@parameterize("pp_size", [2])
+@parameterize("max_output_len", [4])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+    pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+    torch.cuda.empty_cache()
+
+
+@parameterize("tp_size", [2])
+@parameterize("pp_size", [1])
+@parameterize("max_output_len", [2])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+    pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+    torch.cuda.empty_cache()
+
+
+def check_pipeline_inference(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_pipeline_inference_test()
+
+
+def check_tp_pipeline_inference(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_tp_pipeline_inference_test()
+
+
+def check_tp_inference(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_tp_inference_test()
+
+
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_pipeline_inference():
+    spawn(check_pipeline_inference, nprocs=2)
+    spawn(check_tp_pipeline_inference, nprocs=4)
+    spawn(check_tp_inference, nprocs=2)
+
+
+if __name__ == "__main__":
+    test_pipeline_inference()
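
Example usage (not part of the diff): the sketch below shows how the new ChatGLM2 support is driven end to end, mirroring the unit test above. It is a minimal sketch, not the canonical entry point: it assumes the script is launched with two ranks (e.g. torchrun --nproc_per_node 2) so that pp_size=2 can be satisfied, that colossalai.launch_from_torch is used for process-group setup, and that lightllm kernels plus a CUDA build newer than 11.5 are available. The tiny ChatGLMConfig and the random input batch are illustrative only; in practice a real ChatGLM2-6B checkpoint and tokenizer would be used.

import torch

import colossalai
from colossalai.inference import CaiInferEngine, ChatGLM2InferPolicy
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration

# initialize the distributed environment from the torchrun-provided env vars (assumed launch style)
colossalai.launch_from_torch(config={})

# tiny illustrative config; replace with a real ChatGLM2-6B checkpoint in practice
config = ChatGLMConfig(
    num_layers=2,
    vocab_size=20000,
    use_cache=True,
    multi_query_attention=True,
    multi_query_group_num=2,
    num_attention_heads=8,
    hidden_size=1024,
)
model = ChatGLMForConditionalGeneration(config)

# pipeline-parallel inference across the two ranks, driven by the new ChatGLM2 policy
engine = CaiInferEngine(
    tp_size=1,
    pp_size=2,
    model=model,
    model_policy=ChatGLM2InferPolicy(),
    max_output_len=8,
    micro_batch_size=1,
)

# dummy batch of 4 prompts of length 16; real input_ids would come from the ChatGLM2 tokenizer
inputs = dict(
    input_ids=torch.randint(0, 20000, (4, 16), dtype=torch.int64).cuda(),
    attention_mask=torch.ones(4, 16, dtype=torch.int64).cuda(),
)
output = engine.inference(inputs)  # per the test above, each returned sequence holds max_output_len generated token ids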