mirror of https://github.com/hpcaitech/ColossalAI
* RPC support source
* KV cache logical/physical disaggregation
* Sampler refactor
* ColossalAI launch built in
* Unit tests
* RPyC support

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Branch: pull/5716/head
Runyu Lu committed 6 months ago via GitHub
15 changed files with 1032 additions and 63 deletions
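The RPC path added here builds on RPyC's client/server model: each tensor-parallel rank runs a `ThreadedServer` exposing a worker service, and the engine process connects to every server as a client and drives it through its `exposed_*` methods. As a minimal sketch of that pattern (the `EchoService` class, the port number, and the timing are illustrative and not part of this PR):

import time
from threading import Thread

import rpyc
from rpyc.utils.server import ThreadedServer


class EchoService(rpyc.Service):
    # Only methods prefixed with `exposed_` are callable from the client side.
    def exposed_echo(self, msg: str) -> str:
        return f"echo: {msg}"


if __name__ == "__main__":
    # Server side: serve EchoService on an arbitrary local port, here in a background thread.
    server = ThreadedServer(EchoService, port=18861, protocol_config={"allow_public_attrs": True})
    Thread(target=server.start, daemon=True).start()
    time.sleep(0.5)  # give the server a moment to start listening

    # Client side: connect and invoke the exposed method through conn.root.
    conn = rpyc.connect("localhost", 18861)
    print(conn.root.echo("hello"))                      # synchronous call
    async_res = rpyc.async_(conn.root.echo)("world")    # non-blocking call, the style used by the engine below
    async_res.wait()
    print(async_res.value)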
@@ -0,0 +1,291 @@
import asyncio
from itertools import count
from time import sleep
from typing import List, Tuple, Union

import rpyc
import torch
import torch.nn as nn
from rpyc.utils.server import ThreadedServer
from torch import multiprocessing as mp
from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.configuration_utils import PretrainedConfig

from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.executor.rpc_worker import rpcWorkerService
from colossalai.inference.utils import find_available_ports
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy

from .engine import InferenceEngine
from .request_handler import RPCRequestHandler

__all__ = ["RPCInferenceEngine"]


def run_server(host, port, event: mp.Event = None):
    server = ThreadedServer(
        rpcWorkerService, port=port, protocol_config={"allow_public_attrs": True, "allow_all_attrs": True}
    )
    if event:
        event.set()
    server.start()

class RPCInferenceEngine(InferenceEngine):
    """
    InferenceEngine that manages the inference process over RPC workers.

    NOTE: This `RPCInferenceEngine` is designed for multi-card/online serving.
    The original `InferenceEngine` is designed for single-card, offline service, though it supports multi-card offline inference.

    Args:
        model_or_path (nn.Module or str): Path to the model. Passing an `nn.Module` is currently not supported.
        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): The tokenizer to use.
        inference_config (InferenceConfig): Stores the configuration information related to inference.
        verbose (bool): Whether to log the generation process.
        model_policy (Policy): The policy used to shard the model with shardformer. Determined by the model type if not provided.
    """

    def __init__(
        self,
        model_or_path: Union[nn.Module, str],
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        inference_config: InferenceConfig,
        verbose: bool = False,
        model_policy: Policy = None,
    ) -> None:
        """
        If you pass a real model loaded by transformers, the init will take quite a long time.
        Passing the model as an `nn.Module` is currently not supported.
        """

        torch.multiprocessing.set_start_method("spawn", force=True)

        self.inference_config = inference_config
        self.tokenizer = tokenizer
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.verbose = verbose
        self.logger = get_dist_logger(__name__)

        try:
            if isinstance(model_or_path, str):
                self.model_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
            elif isinstance(model_or_path, nn.Module):
                self.logger.error(
                    f"An exception occurred during loading model config: {__class__.__name__} does not support nn.Module as the model argument yet\n"
                )
                # self.model_config = model_or_path.config
            else:
                self.logger.error(
                    f"An exception occurred during loading model config: please pass a valid model argument to {__class__.__name__}\n"
                )
        except Exception as e:
            self.logger.error(
                f"An exception occurred during loading model config: {e}. The path should point to a transformers-style checkpoint\n"
            )
        self.generation_config = inference_config.to_generation_config(self.model_config)

        self.tp_size = inference_config.tp_size
        self.events = [mp.Event() for _ in range(self.tp_size)]

        # This operation will init the dist env and models on the workers.
        self.workers: List[rpcWorkerService] = []
        self.init_workers()

        asyncio.run(self.init_model(model_or_path, model_policy))

        # Init the scheduler and the logical block manager.
        self.request_handler = self.init_scheduler(self.inference_config, self.model_config)

        # Init the physical cache on the workers.
        alloc_shape = self.request_handler.cache_manager.get_physical_cache_shape()
        self.init_device_cache(alloc_shape)

        self.use_cuda_graph = self.inference_config.use_cuda_graph
        self.high_precision = inference_config.high_precision
        self.dtype = inference_config.dtype

        # Model and related attrs of speculative decoding will be set by `enable_spec_dec`.
        self.use_spec_dec = False
        self.drafter_model = None
        self.drafter = None
        self.use_glide = False
        self.n_spec_tokens = self.inference_config.max_n_spec_tokens

        self.counter = count()
        self._verify_args()

        self.logger.info("engine init done")

    def _verify_args(self) -> None:
        """Verify the input args."""
        if not isinstance(self.inference_config, InferenceConfig):
            raise TypeError("Invalid type of inference config provided.")
        if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
            raise TypeError(
                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
            )

    def init_workers(self):
        rpc_ports = find_available_ports(self.tp_size)
        self.worker_processes = []
        # mp.set_start_method('spawn')
        for event, rpc_port in zip(self.events, rpc_ports):
            p = mp.Process(target=run_server, args=("localhost", rpc_port, event))
            p.start()
            self.worker_processes.append(p)
            self.logger.info(f"Starting RPC Worker on localhost:{rpc_port}...")

        # Wait for all servers to start.
        for event in self.events:
            event.wait()
            event.clear()

        sleep(0.05)

        self.logger.info("RPC servers initialized.")

        for rpc_port in rpc_ports:
            try:
                conn = rpyc.connect(
                    "localhost",
                    rpc_port,
                    config={"allow_pickle": True, "allow_public_attrs": True, "allow_all_attrs": True},
                )
                self.workers.append(conn.root)
            except Exception as e:
                raise Exception(f"Failed to connect to the RPC worker on port {rpc_port}") from e
        self.logger.info("RPC connections established. Begin to load model...")
        asyncio.run(self.init_worker_env())
        self.logger.info("Distributed environment initialized on all workers.")

    async def async_parallel_wrapper(self, f, *args, **kwargs):
        async_res = rpyc.async_(f)(*args, **kwargs)
        await asyncio.to_thread(async_res.wait)
        assert async_res.ready
        return async_res.value

    async def init_worker_env(self):
        assert len(self.workers) == self.tp_size, "init workers first"

        dist_group_port = find_available_ports(1)[0]
        init_tasks = [
            self.async_parallel_wrapper(
                worker.init_dist_env, rank, self.inference_config.tp_size, "127.0.0.1", dist_group_port
            )
            for rank, worker in enumerate(self.workers)
        ]

        await asyncio.gather(*init_tasks)

    async def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
        assert len(self.workers) == self.tp_size, "init workers first"

        inference_config_param = self.inference_config.to_rpc_param()
        model_path = model_or_path
        model_policy_param = model_policy.to_rpc_param() if model_policy else None

        init_tasks = [
            self.async_parallel_wrapper(worker.init_model, inference_config_param, model_path, model_policy_param)
            for rank, worker in enumerate(self.workers)
        ]

        await asyncio.gather(*init_tasks)

    def init_scheduler(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> RPCRequestHandler:
        return RPCRequestHandler(inference_config, model_config)

    async def _init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]):
        assert len(self.workers) == self.tp_size, "init workers first"

        init_tasks = [self.async_parallel_wrapper(worker.init_cache, alloc_shape) for worker in self.workers]

        await asyncio.gather(*init_tasks)

    def init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]):
        asyncio.run(self._init_device_cache(alloc_shape))

    def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]:
        input_ids = batch.get_1D_inputs()
        sequence_lengths = batch.get_sequence_lengths()

        if batch.is_prompts:
            n_tokens = sequence_lengths.sum().item()
        else:
            n_tokens = batch.current_batch_size
            if batch.use_spec_dec:
                n_tokens = batch.num_tokens_to_verify + 1
                assert n_tokens == input_ids.size(0)
                n_tokens = n_tokens * batch.current_batch_size

        batch_token_ids = None
        config_dict = self.generation_config.to_dict()
        # Batch token ids are only needed when repetition_penalty or no_repeat_ngram_size is set.
        for option in ["repetition_penalty", "no_repeat_ngram_size"]:
            if option in config_dict and config_dict[option] is not None:
                batch_token_ids = batch.batch_token_ids

        # CUDA graph can only be used for decoding when a graph has been captured for this batch size.
        use_cuda_graph = False
        if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
            use_cuda_graph = True

        input_meta_data = InputMetaData(
            block_tables=batch.get_block_table_tensor(),
            sequence_lengths=sequence_lengths,
            fd_inter_tensor=None,
            batch_size=batch.current_batch_size,
            is_prompts=batch.is_prompts,
            use_cuda_kernel=self.inference_config.use_cuda_kernel,
            use_cuda_graph=use_cuda_graph,
            high_precision=self.high_precision,
            kv_seq_len=sequence_lengths.max().item(),
            head_dim=batch.head_dim,
            dtype=batch.dtype,
            use_spec_dec=batch.use_spec_dec,
            num_tokens_to_verify=batch.num_tokens_to_verify,
            batch_token_ids=batch_token_ids,
        )

        return input_ids.tolist(), input_meta_data

    async def step_(self, input_token_ids, input_meta_data: InputMetaData):
        assert len(self.workers) == self.tp_size, "init workers first"

        init_tasks = [
            self.async_parallel_wrapper(worker.execute_model_forward, input_token_ids, input_meta_data.to_rpc_param())
            for worker in self.workers
        ]
        ret = await asyncio.gather(*init_tasks)

        return ret[0]

    def step(self) -> List[str]:
        batch = self.request_handler.schedule()

        input_token_ids, input_meta_data = self.prepare_input(batch)
        # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
        next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data))

        # Update the request handler with the newly generated tokens.
        next_tokens = torch.tensor(next_tokens, dtype=torch.int)
        self.request_handler.append_next_tokens(next_tokens)
        finished_sequences = self.request_handler.update()
        return finished_sequences

    def kill_workers(self):
        """
        Kill all RPC worker processes. There is no reliable way to have this invoked
        implicitly, so it is also called from `__del__`.
        """
        assert len(self.workers) != 0
        for proc in self.worker_processes:
            proc.kill()
            proc.join()
        self.logger.info("Workers killed, serving ended.")

    def __del__(self):
        self.kill_workers()
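Read together with the unit test further down, driving the engine from user code looks roughly like the sketch below. The model path, prompts, and config values are illustrative only; `add_request` and `generate` come from the parent `InferenceEngine`.

from transformers import AutoTokenizer, GenerationConfig

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.rpc_engine import RPCInferenceEngine

# Illustrative values; any transformers-style checkpoint path should work.
model_path = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# tp_size controls how many RPC worker processes the engine spawns.
inference_config = InferenceConfig(max_output_len=64, dtype="fp16", tp_size=2)
engine = RPCInferenceEngine(model_path, tokenizer, inference_config, verbose=True)

engine.add_request(prompts=["Introduce Beijing:", "Introduce Wuhan:"])
outputs = engine.generate(generation_config=GenerationConfig(max_new_tokens=64, do_sample=False))
print(outputs)

# Worker processes are killed in __del__, but it can also be done explicitly.
engine.kill_workers()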
@@ -0,0 +1,300 @@
import os
from typing import List, Tuple, Union

import rpyc
import torch
import torch.distributed as dist
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.policy import (
    NoPaddingBaichuanModelInferPolicy,
    NoPaddingLlamaModelInferPolicy,
    model_policy_map,
)
from colossalai.inference.sampler import search_tokens
from colossalai.inference.utils import get_model_size, has_index_file
from colossalai.interface import ModelWrapper
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy

PP_AXIS, TP_AXIS = 0, 1

_SUPPORTED_MODELS = {
    "LlamaForCausalLM": LlamaForCausalLM,
    "BaichuanForCausalLM": AutoModelForCausalLM,
}

_SUPPORTED_MODEL_POLICIES = {
    "NoPaddingLlamaModelInferPolicy": NoPaddingLlamaModelInferPolicy,
    "NoPaddingBaichuanModelInferPolicy": NoPaddingBaichuanModelInferPolicy,
}

logger = get_dist_logger(__name__)


class rpcWorkerService(rpyc.Service):
    """
    Execute the computation tasks and manage the worker's own KV cache.

    Methods prefixed with `exposed_` can be invoked by the client.
    """

    def exposed_init_dist_env(self, rank, world_size, master_address, master_port):
        logger.info(f"init process group for rank {rank}")
        colossalai.launch(rank=rank, world_size=world_size, port=master_port, host=master_address)
        logger.info(f"init process group done for rank {rank}")

    def exposed_init_model(
        self, inference_config_param: dict, model_or_path: Union[nn.Module, str], model_policy_param: str = None
    ):
        assert dist.is_initialized(), "invoke init_dist_env first please!"

        self.inference_config = InferenceConfig.from_rpc_param(inference_config_param)
        model_policy = _SUPPORTED_MODEL_POLICIES[model_policy_param]() if model_policy_param else None

        self.dtype = self.inference_config.dtype
        self.verbose = True

        self._init_model(model_or_path, model_policy)
        self._init_fd_tensor()
        self._init_output_tensor()
        logger.info(f"init model done for rank {dist.get_rank()}")

    def exposed_init_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]):
        """Initialize the physical cache on the device.

        For each layer of the model, we allocate two tensors, one for keys and one for values,
        each with shape [num_blocks, num_kv_heads, block_size, head_size].
        """
        kalloc_shape, valloc_shape = alloc_shape
        num_layers = self.model_config.num_hidden_layers

        self.k_cache: List[torch.Tensor] = []
        self.v_cache: List[torch.Tensor] = []
        for _ in range(num_layers):
            self.k_cache.append(
                torch.zeros(
                    kalloc_shape,
                    dtype=self.inference_config.kv_cache_dtype,
                    device=get_accelerator().get_current_device(),
                )
            )
            self.v_cache.append(
                torch.zeros(
                    valloc_shape,
                    dtype=self.inference_config.kv_cache_dtype,
                    device=get_accelerator().get_current_device(),
                )
            )
        logger.info("physical cache initialized")

    def exposed_execute_model_forward(self, input_token_ids_param: List[int], input_meta_data_param: dict):
        # Prepare the data for the model forward pass.
        input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
        input_meta_data.fd_inter_tensor = self.fd_inter_tensor
        if input_meta_data.is_prompts:
            n_tokens = input_meta_data.sequence_lengths.sum().item()
        else:
            n_tokens = input_meta_data.batch_size
        input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device)

        # Execute the model.
        logits = self.model(
            input_token_ids,
            self.output_tensor[:n_tokens],
            input_meta_data,
            self.k_cache,
            self.v_cache,
        )

        # Sample the next tokens.
        if self.inference_config.pad_input:
            logits = logits[:, -1, :]
        next_tokens = search_tokens(
            self.inference_config.to_generation_config(self.model_config),
            logits,
            input_meta_data.is_prompts,
            input_meta_data.batch_token_ids,
        )

        # Return the generated tokens to the scheduler.
        return next_tokens.tolist()

    def _init_output_tensor(self):
        alloc_shape = (
            self.inference_config.max_batch_size
            * (self.inference_config.max_input_len + self.inference_config.max_output_len),
            self.model_config.hidden_size // self.inference_config.tp_size,
        )
        self.output_tensor = torch.zeros(alloc_shape, dtype=self.dtype, device=self.device)

    def _init_fd_tensor(self):
        fd_inter_tensor = FDIntermTensors()

        if fd_inter_tensor._tensors_initialized:
            fd_inter_tensor._reset()

        # For Spec-Dec, process the speculated tokens plus the token from the last step for each sequence.
        max_n_tokens = self.inference_config.max_batch_size
        max_n_tokens *= self.inference_config.max_n_spec_tokens + 1

        inference_config = self.inference_config
        kv_max_split_num = (
            inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
        ) // inference_config.block_size
        head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads

        fd_inter_tensor.initialize(
            max_batch_size=max_n_tokens,
            num_attn_heads=self.model_config.num_attention_heads // self.inference_config.tp_size,
            kv_max_split_num=kv_max_split_num,
            head_dim=head_dim,
            dtype=self.dtype,
            device=get_accelerator().get_current_device(),
        )

        self.fd_inter_tensor = fd_inter_tensor

    def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
        """
        Shard the model and/or load its weights.

        Shard model: when tp_size > 1, the model is sharded with the given model_policy.
        Load weights: if a local model path is passed, weights are loaded via checkpoint_io;
        if it is a remote transformers repo, `from_pretrained` of the transformers library is used.

        Args:
            model_or_path (Union[nn.Module, str]): path to the checkpoint, or a model in transformers format.
            model_policy (Policy): the policy used to replace the model modules.
        """

        if isinstance(model_or_path, str):
            is_local = os.path.isdir(model_or_path)
            try:
                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
                arch = getattr(hf_config, "architectures")[0]
                if is_local:
                    model = _SUPPORTED_MODELS[arch](hf_config)
                else:
                    # Load the real checkpoint from the hub.
                    model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
            except Exception as e:
                logger.error(
                    f"An exception occurred during loading model: {e}. The model should be loadable by transformers\n"
                )
        else:
            model = model_or_path

        self.model_config = model.config

        torch.cuda.empty_cache()
        init_gpu_memory = torch.cuda.mem_get_info()[0]

        self.device = get_accelerator().get_current_device()
        torch.cuda.set_device(self.device)
        if self.verbose:
            logger.info(f"the device is {self.device}")

        model = model.to(dtype=self.dtype, non_blocking=False).eval()

        if self.verbose:
            logger.info(
                f"Before sharding, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
            )

        if model_policy is None:
            if self.inference_config.pad_input:
                model_type = "padding_" + self.model_config.model_type
            else:
                model_type = "nopadding_" + self.model_config.model_type
            model_policy = model_policy_map[model_type]()

        pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
        tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

        self.model = self._shardformer(
            model,
            model_policy,
            None,
            tp_group=tp_group,
        )

        self.model = ModelWrapper(model).to(device=get_accelerator().get_current_device())

        if self.verbose:
            logger.info(
                f"After sharding, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
            )

        if isinstance(model_or_path, str) and is_local:
            from colossalai.inference.core.plugin import InferCheckpoint_io

            cpt_io = InferCheckpoint_io()
            if_has_index_file, model_index_file = has_index_file(model_or_path)
            assert if_has_index_file, "the model path is invalid"
            cpt_io.load_model(self.model, model_index_file)

        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        peak_memory = init_gpu_memory - free_gpu_memory
        if self.verbose:
            logger.info(
                f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
            )

    def _shardformer(
        self,
        model: nn.Module,
        model_policy: Policy,
        stage_manager: PipelineStageManager = None,
        tp_group: ProcessGroupMesh = None,
    ) -> nn.Module:
        """
        Initialize ShardConfig and replace the model with shardformer.

        Args:
            model (nn.Module): The model to be sharded.
            model_policy (Policy): The policy used to shard the model, determined by the model type.
            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
            tp_group (ProcessGroupMesh, optional): Used to manage the TP process group mesh. Defaults to None.

        Returns:
            nn.Module: The model optimized by Shardformer.
        """

        shardconfig = ShardConfig(
            tensor_parallel_process_group=tp_group,
            pipeline_stage_manager=stage_manager,
            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
            enable_fused_normalization=False,
            enable_all_optimization=False,
            enable_flash_attention=False,
            enable_jit_fused=False,
            enable_sequence_parallelism=False,
        )
        shardformer = ShardFormer(shard_config=shardconfig)
        shard_model, _ = shardformer.optimize(model, model_policy)
        return shard_model

    def exposed_compute_only_for_test(self):
        dist_rank = dist.get_rank()

        # Dummy data for each worker.
        data = torch.tensor([dist_rank], dtype=torch.float).cuda(dist_rank)
        dist.barrier()

        # Perform a distributed all_reduce.
        dist.all_reduce(data, op=dist.ReduceOp.SUM)

        dist.barrier()
        logger.info(f"Worker rank {dist_rank}: Sum after all_reduce: {data.item()}")

        return data.item()
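Since `exposed_init_cache` above allocates two tensors (key and value) per layer, each of shape [num_blocks, num_kv_heads, block_size, head_size], the per-rank KV cache footprint can be estimated up front. A small back-of-the-envelope helper; the example numbers are made up for illustration, not taken from this PR:

import torch


def kv_cache_bytes(num_layers: int, num_blocks: int, num_kv_heads: int,
                   block_size: int, head_size: int, dtype: torch.dtype = torch.float16) -> int:
    """Total bytes for the paged KV cache: 2 tensors (K and V) per layer."""
    elems_per_tensor = num_blocks * num_kv_heads * block_size * head_size
    bytes_per_elem = torch.tensor([], dtype=dtype).element_size()
    return 2 * num_layers * elems_per_tensor * bytes_per_elem


# Example: a Llama-2-7B-like config sharded over tp_size=2 (32 layers, 32 KV heads / 2 ranks,
# head_size 128) with 512 blocks of 16 tokens each comes to roughly 2 GiB per rank in fp16.
print(kv_cache_bytes(num_layers=32, num_blocks=512, num_kv_heads=16, block_size=16, head_size=128) / 1024**3)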
@@ -1,4 +1,4 @@
 from .block_cache import CacheBlock
-from .kvcache_manager import KVCacheManager
+from .kvcache_manager import KVCacheManager, RPCKVCacheManager

-__all__ = ["CacheBlock", "KVCacheManager"]
+__all__ = ["CacheBlock", "KVCacheManager", "RPCKVCacheManager"]
@@ -0,0 +1,105 @@
import random

import numpy as np
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.rpc_engine import RPCInferenceEngine
from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy
from colossalai.testing import parameterize, rerun_if_address_is_in_use


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def check_inference_engine(tp_size, use_engine=False, prompt_template=None, do_sample=True, policy=None):
    setup_seed(20)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    model = "meta-llama/Llama-2-7b-hf"  # remote model path
    inputs = [
        # Chinese prompts: "Introduce today's Beijing, e.g. the Forbidden City, Tiananmen, the Great Wall or other sights," / "Introduce Wuhan,"
        "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
        "介绍一下武汉,",
    ]

    output_len = 38
    top_p = 0.5
    top_k = 50

    if use_engine:
        inference_config = InferenceConfig(
            max_output_len=output_len,
            prompt_template=prompt_template,
            dtype="fp32",
            use_cuda_kernel=True,
            tp_size=tp_size,
        )
        inference_engine = RPCInferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
        assert inference_engine.generation_config.max_new_tokens == output_len
        inference_engine.add_request(prompts=inputs)
        assert inference_engine.request_handler._has_waiting()
        generation_config = GenerationConfig(
            max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
        )
        outputs = inference_engine.generate(generation_config=generation_config)
    else:
        if prompt_template:
            # Apply the prompt template.
            inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
        model = AutoModelForCausalLM.from_pretrained(model).cuda()
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
        inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
        inputs = inputs.cuda()
        generation_config = GenerationConfig(
            do_sample=do_sample,
            dtype="fp32",
            top_p=top_p,
            top_k=top_k,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=output_len,
        )
        outputs = model.generate(inputs, generation_config=generation_config)
        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    return outputs


def run_engine(tp_size, **kwargs):
    return check_inference_engine(tp_size=tp_size, **kwargs)


@pytest.mark.largedist
@parameterize("prompt_template", [None, "llama"])
@parameterize("do_sample", [False])
@rerun_if_address_is_in_use()
def test_tp_engine(prompt_template, do_sample):
    if torch.multiprocessing.get_start_method(allow_none=True) is None:
        torch.multiprocessing.set_start_method("spawn")
    kwargs1 = {
        "use_engine": True,
        "prompt_template": prompt_template,
        "do_sample": do_sample,
        "policy": NoPaddingLlamaModelInferPolicy(),
    }

    kwargs2 = {"use_engine": False, "prompt_template": prompt_template, "do_sample": do_sample, "policy": None}

    colossal_tp_1_output = run_engine(1, **kwargs1)
    colossal_tp_2_output = run_engine(2, **kwargs1)
    transformer_tp_1_output = run_engine(1, **kwargs2)

    for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
        assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
        assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"


if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")  # workers must be spawned; forking subprocesses is not supported here
    test_tp_engine()