from typing import List, Tuple, Union import rpyc import torch import torch.distributed as dist from torch import nn from transformers import AutoConfig, AutoModelForCausalLM from transformers.models.llama.modeling_llama import LlamaForCausalLM import colossalai from colossalai.accelerator import get_accelerator from colossalai.cluster import ProcessGroupMesh from colossalai.inference.config import InferenceConfig, InputMetaData from colossalai.inference.flash_decoding_utils import FDIntermTensors from colossalai.inference.modeling.policy import ( NoPaddingBaichuanModelInferPolicy, NoPaddingLlamaModelInferPolicy, model_policy_map, ) from colossalai.inference.sampler import search_tokens from colossalai.inference.utils import get_model_size from colossalai.interface import ModelWrapper from colossalai.logging import get_dist_logger from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.policies.base_policy import Policy PP_AXIS, TP_AXIS = 0, 1 _SUPPORTED_MODELS = { "LlamaForCausalLM": LlamaForCausalLM, "BaichuanForCausalLM": AutoModelForCausalLM, } _SUPPORTED_MODEL_POLICIES = { "NoPaddingLlamaModelInferPolicy": NoPaddingLlamaModelInferPolicy, "NoPaddingBaichuanModelInferPolicy": NoPaddingBaichuanModelInferPolicy, } logger = get_dist_logger(__name__) class rpcWorkerService(rpyc.Service): """ Execute the computation tasks and manage its own kv cache Func with prefix `exposed_` will be invoked by client. """ def exposed_init_dist_env(self, rank, world_size, master_address, master_port): logger.info(f"init process group for rank {rank}") colossalai.launch(rank=rank, world_size=world_size, port=master_port, host=master_address) logger.info(f"init process group done for rank {rank}") def exposed_init_model( self, inference_config_param: dict, model_or_path: Union[nn.Module, str], model_policy_param: str = None ): assert dist.is_initialized(), "invoke init_dist_env first please!" self.inference_config = InferenceConfig.from_rpc_param(inference_config_param) model_policy = _SUPPORTED_MODEL_POLICIES[model_policy_param]() if model_policy_param else None self.dtype = self.inference_config.dtype self.verbose = True self._init_model(model_or_path, model_policy) self._init_fd_tensor() self._init_output_tensor() logger.info(f"init model done for rank {dist.get_rank()}") def exposed_init_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]): """Initialize the physical cache on the device. For each layer of the model, we allocate two tensors for key and value respectively, with shape of [num_blocks, num_kv_heads, block_size, head_size] """ kalloc_shape, valloc_shape = alloc_shape num_layers = self.model_config.num_hidden_layers self.k_cache: List[torch.Tensor] = [] self.v_cache: List[torch.Tensor] = [] for _ in range(num_layers): self.k_cache.append( torch.zeros( kalloc_shape, dtype=self.inference_config.kv_cache_dtype, device=get_accelerator().get_current_device(), ) ) self.v_cache.append( torch.zeros( valloc_shape, dtype=self.inference_config.kv_cache_dtype, device=get_accelerator().get_current_device(), ) ) logger.info("physical cache init over") def exposed_execute_model_forward( self, input_token_ids_param: List[int], input_meta_data_param: dict, generation_config_param: dict ): # prepare the data for model forward input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param) input_meta_data.fd_inter_tensor = self.fd_inter_tensor if input_meta_data.is_prompts: n_tokens = input_meta_data.sequence_lengths.sum().item() else: n_tokens = input_meta_data.batch_size input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device) # execute the model logits = self.model( input_token_ids, self.output_tensor[:n_tokens], input_meta_data, self.k_cache, self.v_cache, ) # sampler if self.inference_config.pad_input: logits = logits[:, -1, :] next_tokens = search_tokens( generation_config_param, logits, input_meta_data.is_prompts, input_meta_data.batch_token_ids, ) # return the tokens generated to scheduler return next_tokens.tolist() def _init_output_tensor(self): alloc_shape = ( self.inference_config.max_batch_size * (self.inference_config.max_input_len + self.inference_config.max_output_len), self.model_config.hidden_size // self.inference_config.tp_size, ) self.output_tensor = torch.zeros(alloc_shape, dtype=self.dtype, device=self.device) def _init_fd_tensor(self): fd_inter_tensor = FDIntermTensors() if fd_inter_tensor._tensors_initialized: fd_inter_tensor._reset() # For Spec-Dec, process the speculated tokens plus the token in the last step for each seq max_n_tokens = self.inference_config.max_batch_size max_n_tokens *= self.inference_config.max_n_spec_tokens + 1 inference_config = self.inference_config kv_max_split_num = ( inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1 ) // inference_config.block_size head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads fd_inter_tensor.initialize( max_batch_size=max_n_tokens, num_attn_heads=self.model_config.num_attention_heads // self.inference_config.tp_size, kv_max_split_num=kv_max_split_num, head_dim=head_dim, dtype=self.dtype, device=get_accelerator().get_current_device(), ) self.fd_inter_tensor = fd_inter_tensor def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None): """ Shard model or/and Load weight Shard model: When we set tp_size > 1, we will shard the model by given model_policy. Load Weight: If we pass a local model path, we will load the model weight by checkpoint_io. If it is a remote-transformer url, we will use `AutoModel.from_pretrained` api of transformers lib Args: model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. model_policy (Policy): the policy to replace the model """ if isinstance(model_or_path, str): # is_local = os.path.isdir(model_or_path) try: hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True) arch = getattr(hf_config, "architectures")[0] # NOTE(lry89757) Currently we load the model using transformers-api, # but we will use lazy tensor and checkpoint io to accelerate # the model load process in the future. model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True) # if is_local: # model = _SUPPORTED_MODELS[arch](hf_config) # else: # # load the real checkpoint # model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True) except Exception as e: logger.error( f"An exception occurred during loading model: {e}, model should be loaded by transformers\n" ) else: model = model_or_path self.model_config = model.config torch.cuda.empty_cache() init_gpu_memory = torch.cuda.mem_get_info()[0] self.device = get_accelerator().get_current_device() torch.cuda.set_device(self.device) if self.verbose: logger.info(f"the device is {self.device}") model = model.to(dtype=self.dtype, non_blocking=False).eval() if self.verbose: logger.info( f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}" ) if model_policy is None: if self.inference_config.pad_input: model_type = "padding_" + self.model_config.model_type else: model_type = "nopadding_" + self.model_config.model_type model_policy = model_policy_map[model_type]() pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) tp_group = pg_mesh.get_group_along_axis(TP_AXIS) self.model = self._shardformer( model, model_policy, None, tp_group=tp_group, ) self.model = ModelWrapper(model).to(device=get_accelerator().get_current_device()) if self.verbose: logger.info( f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" ) # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor # if isinstance(model_or_path, str) and is_local: # from colossalai.inference.core.plugin import InferCheckpoint_io # cpt_io = InferCheckpoint_io() # if_has_index_file, model_index_file = has_index_file(model_or_path) # assert if_has_index_file, "the model path is invalid" # cpt_io.load_model(self.model, model_index_file) free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() peak_memory = init_gpu_memory - free_gpu_memory if self.verbose: logger.info( f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB" ) def _shardformer( self, model: nn.Module, model_policy: Policy, stage_manager: PipelineStageManager = None, tp_group: ProcessGroupMesh = None, ) -> nn.Module: """ Initialize ShardConfig and replace the model with shardformer. Args: model (nn.Module): Path or nn.Module of this model. model_policy (Policy): The policy to shardformer model which is determined by the model type. stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None. tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None. Returns: nn.Module: The model optimized by Shardformer. """ shardconfig = ShardConfig( tensor_parallel_process_group=tp_group, pipeline_stage_manager=stage_manager, enable_tensor_parallelism=(self.inference_config.tp_size > 1), enable_fused_normalization=False, enable_all_optimization=False, enable_flash_attention=False, enable_jit_fused=False, enable_sequence_parallelism=False, ) shardformer = ShardFormer(shard_config=shardconfig) shard_model, _ = shardformer.optimize(model, model_policy) return shard_model def exposed_compute_only_for_test(self): dist_rank = dist.get_rank() # Dummy data for each worker data = torch.tensor([dist_rank], dtype=torch.float).cuda(dist_rank) dist.barrier() # Perform distributed all_reduce dist.all_reduce(data, op=dist.ReduceOp.SUM) dist.barrier() logger.info(f"Worker rank {dist_rank}: Sum after all_reduce: {data.item()}") return data.item()