mirror of https://github.com/hpcaitech/ColossalAI
301 lines
12 KiB
Python
301 lines
12 KiB
Python
import os
|
|
from typing import List, Tuple, Union
|
|
|
|
import rpyc
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch import nn
|
|
from transformers import AutoConfig, AutoModelForCausalLM
|
|
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
|
|
|
import colossalai
|
|
from colossalai.accelerator import get_accelerator
|
|
from colossalai.cluster import ProcessGroupMesh
|
|
from colossalai.inference.config import InferenceConfig, InputMetaData
|
|
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
|
from colossalai.inference.modeling.policy import (
|
|
NoPaddingBaichuanModelInferPolicy,
|
|
NoPaddingLlamaModelInferPolicy,
|
|
model_policy_map,
|
|
)
|
|
from colossalai.inference.sampler import search_tokens
|
|
from colossalai.inference.utils import get_model_size, has_index_file
|
|
from colossalai.interface import ModelWrapper
|
|
from colossalai.logging import get_dist_logger
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
|
from colossalai.shardformer import ShardConfig, ShardFormer
|
|
from colossalai.shardformer.policies.base_policy import Policy
|
|
|
|
PP_AXIS, TP_AXIS = 0, 1
|
|
|
|
_SUPPORTED_MODELS = {
|
|
"LlamaForCausalLM": LlamaForCausalLM,
|
|
"BaichuanForCausalLM": AutoModelForCausalLM,
|
|
}
|
|
|
|
_SUPPORTED_MODEL_POLICIES = {
|
|
"NoPaddingLlamaModelInferPolicy": NoPaddingLlamaModelInferPolicy,
|
|
"NoPaddingBaichuanModelInferPolicy": NoPaddingBaichuanModelInferPolicy,
|
|
}
|
|
|
|
logger = get_dist_logger(__name__)
|
|
|
|
|
|
class rpcWorkerService(rpyc.Service):
|
|
|
|
"""
|
|
Execute the computation tasks and manage its own kv cache
|
|
|
|
Func with prefix `exposed_` will be invoked by client.
|
|
"""
|
|
|
|
def exposed_init_dist_env(self, rank, world_size, master_address, master_port):
|
|
logger.info(f"init process group for rank {rank}")
|
|
colossalai.launch(rank=rank, world_size=world_size, port=master_port, host=master_address)
|
|
logger.info(f"init process group done for rank {rank}")
|
|
|
|
def exposed_init_model(
|
|
self, inference_config_param: dict, model_or_path: Union[nn.Module, str], model_policy_param: str = None
|
|
):
|
|
assert dist.is_initialized(), "invoke init_dist_env first please!"
|
|
|
|
self.inference_config = InferenceConfig.from_rpc_param(inference_config_param)
|
|
model_policy = _SUPPORTED_MODEL_POLICIES[model_policy_param]() if model_policy_param else None
|
|
|
|
self.dtype = self.inference_config.dtype
|
|
self.verbose = True
|
|
|
|
self._init_model(model_or_path, model_policy)
|
|
self._init_fd_tensor()
|
|
self._init_output_tensor()
|
|
logger.info(f"init model done for rank {dist.get_rank()}")
|
|
|
|
def exposed_init_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]):
|
|
"""Initialize the physical cache on the device.
|
|
|
|
For each layer of the model, we allocate two tensors for key and value respectively,
|
|
with shape of [num_blocks, num_kv_heads, block_size, head_size]
|
|
"""
|
|
kalloc_shape, valloc_shape = alloc_shape
|
|
num_layers = self.model_config.num_hidden_layers
|
|
|
|
self.k_cache: List[torch.Tensor] = []
|
|
self.v_cache: List[torch.Tensor] = []
|
|
for _ in range(num_layers):
|
|
self.k_cache.append(
|
|
torch.zeros(
|
|
kalloc_shape,
|
|
dtype=self.inference_config.kv_cache_dtype,
|
|
device=get_accelerator().get_current_device(),
|
|
)
|
|
)
|
|
self.v_cache.append(
|
|
torch.zeros(
|
|
valloc_shape,
|
|
dtype=self.inference_config.kv_cache_dtype,
|
|
device=get_accelerator().get_current_device(),
|
|
)
|
|
)
|
|
logger.info("physical cache init over")
|
|
|
|
def exposed_execute_model_forward(self, input_token_ids_param: List[int], input_meta_data_param: dict):
|
|
# prepare the data for model forward
|
|
input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
|
|
input_meta_data.fd_inter_tensor = self.fd_inter_tensor
|
|
if input_meta_data.is_prompts:
|
|
n_tokens = input_meta_data.sequence_lengths.sum().item()
|
|
else:
|
|
n_tokens = input_meta_data.batch_size
|
|
input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device)
|
|
|
|
# execute the model
|
|
logits = self.model(
|
|
input_token_ids,
|
|
self.output_tensor[:n_tokens],
|
|
input_meta_data,
|
|
self.k_cache,
|
|
self.v_cache,
|
|
)
|
|
|
|
# sampler
|
|
if self.inference_config.pad_input:
|
|
logits = logits[:, -1, :]
|
|
next_tokens = search_tokens(
|
|
self.inference_config.to_generation_config(self.model_config),
|
|
logits,
|
|
input_meta_data.is_prompts,
|
|
input_meta_data.batch_token_ids,
|
|
)
|
|
|
|
# return the tokens generated to scheduler
|
|
return next_tokens.tolist()
|
|
|
|
def _init_output_tensor(self):
|
|
alloc_shape = (
|
|
self.inference_config.max_batch_size
|
|
* (self.inference_config.max_input_len + self.inference_config.max_output_len),
|
|
self.model_config.hidden_size // self.inference_config.tp_size,
|
|
)
|
|
self.output_tensor = torch.zeros(alloc_shape, dtype=self.dtype, device=self.device)
|
|
|
|
def _init_fd_tensor(self):
|
|
fd_inter_tensor = FDIntermTensors()
|
|
|
|
if fd_inter_tensor._tensors_initialized:
|
|
fd_inter_tensor._reset()
|
|
|
|
# For Spec-Dec, process the speculated tokens plus the token in the last step for each seq
|
|
max_n_tokens = self.inference_config.max_batch_size
|
|
max_n_tokens *= self.inference_config.max_n_spec_tokens + 1
|
|
|
|
inference_config = self.inference_config
|
|
kv_max_split_num = (
|
|
inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
|
|
) // inference_config.block_size
|
|
head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
|
|
|
|
fd_inter_tensor.initialize(
|
|
max_batch_size=max_n_tokens,
|
|
num_attn_heads=self.model_config.num_attention_heads // self.inference_config.tp_size,
|
|
kv_max_split_num=kv_max_split_num,
|
|
head_dim=head_dim,
|
|
dtype=self.dtype,
|
|
device=get_accelerator().get_current_device(),
|
|
)
|
|
|
|
self.fd_inter_tensor = fd_inter_tensor
|
|
|
|
def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
|
|
"""
|
|
Shard model or/and Load weight
|
|
|
|
Shard model: When we set tp_size > 1, we will shard the model by given model_policy.
|
|
Load Weight: If we pass a local model path, we will load the model weight by checkpoint_io. If it is a remote-transformer url, we will use `AutoModel.from_pretrained` api of transformers lib
|
|
|
|
Args:
|
|
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
|
|
model_policy (Policy): the policy to replace the model
|
|
"""
|
|
|
|
if isinstance(model_or_path, str):
|
|
is_local = os.path.isdir(model_or_path)
|
|
try:
|
|
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
|
|
arch = getattr(hf_config, "architectures")[0]
|
|
if is_local:
|
|
model = _SUPPORTED_MODELS[arch](hf_config)
|
|
else:
|
|
# load the real checkpoint
|
|
model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
|
|
except Exception as e:
|
|
logger.error(
|
|
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
|
|
)
|
|
else:
|
|
model = model_or_path
|
|
|
|
self.model_config = model.config
|
|
|
|
torch.cuda.empty_cache()
|
|
init_gpu_memory = torch.cuda.mem_get_info()[0]
|
|
|
|
self.device = get_accelerator().get_current_device()
|
|
torch.cuda.set_device(self.device)
|
|
if self.verbose:
|
|
logger.info(f"the device is {self.device}")
|
|
|
|
model = model.to(dtype=self.dtype, non_blocking=False).eval()
|
|
|
|
if self.verbose:
|
|
logger.info(
|
|
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
|
|
)
|
|
|
|
if model_policy is None:
|
|
if self.inference_config.pad_input:
|
|
model_type = "padding_" + self.model_config.model_type
|
|
else:
|
|
model_type = "nopadding_" + self.model_config.model_type
|
|
model_policy = model_policy_map[model_type]()
|
|
|
|
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
|
|
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
|
|
|
|
self.model = self._shardformer(
|
|
model,
|
|
model_policy,
|
|
None,
|
|
tp_group=tp_group,
|
|
)
|
|
|
|
self.model = ModelWrapper(model).to(device=get_accelerator().get_current_device())
|
|
|
|
if self.verbose:
|
|
logger.info(
|
|
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
|
|
)
|
|
|
|
if isinstance(model_or_path, str) and is_local:
|
|
from colossalai.inference.core.plugin import InferCheckpoint_io
|
|
|
|
cpt_io = InferCheckpoint_io()
|
|
if_has_index_file, model_index_file = has_index_file(model_or_path)
|
|
assert if_has_index_file, "the model path is invalid"
|
|
cpt_io.load_model(self.model, model_index_file)
|
|
|
|
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
|
|
peak_memory = init_gpu_memory - free_gpu_memory
|
|
if self.verbose:
|
|
logger.info(
|
|
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
|
|
)
|
|
|
|
def _shardformer(
|
|
self,
|
|
model: nn.Module,
|
|
model_policy: Policy,
|
|
stage_manager: PipelineStageManager = None,
|
|
tp_group: ProcessGroupMesh = None,
|
|
) -> nn.Module:
|
|
"""
|
|
Initialize ShardConfig and replace the model with shardformer.
|
|
|
|
Args:
|
|
model (nn.Module): Path or nn.Module of this model.
|
|
model_policy (Policy): The policy to shardformer model which is determined by the model type.
|
|
stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
|
|
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
|
|
|
|
Returns:
|
|
nn.Module: The model optimized by Shardformer.
|
|
"""
|
|
|
|
shardconfig = ShardConfig(
|
|
tensor_parallel_process_group=tp_group,
|
|
pipeline_stage_manager=stage_manager,
|
|
enable_tensor_parallelism=(self.inference_config.tp_size > 1),
|
|
enable_fused_normalization=False,
|
|
enable_all_optimization=False,
|
|
enable_flash_attention=False,
|
|
enable_jit_fused=False,
|
|
enable_sequence_parallelism=False,
|
|
)
|
|
shardformer = ShardFormer(shard_config=shardconfig)
|
|
shard_model, _ = shardformer.optimize(model, model_policy)
|
|
return shard_model
|
|
|
|
def exposed_compute_only_for_test(self):
|
|
dist_rank = dist.get_rank()
|
|
|
|
# Dummy data for each worker
|
|
data = torch.tensor([dist_rank], dtype=torch.float).cuda(dist_rank)
|
|
dist.barrier()
|
|
|
|
# Perform distributed all_reduce
|
|
dist.all_reduce(data, op=dist.ReduceOp.SUM)
|
|
|
|
dist.barrier()
|
|
logger.info(f"Worker rank {dist_rank}: Sum after all_reduce: {data.item()}")
|
|
|
|
return data.item()
|