
[Feat] Inference RPC Server Support (#5705)

* RPC support source
* KV cache logical/physical disaggregation
* Sampler refactor
* Built-in colossalai launch
* Unit tests
* rpyc support

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Branch: pull/5716/head
Runyu Lu, 6 months ago (committed via GitHub)
parent commit 18d67d0e8e
 1. colossalai/inference/config.py (115)
 2. colossalai/inference/core/engine.py (17)
 3. colossalai/inference/core/request_handler.py (95)
 4. colossalai/inference/core/rpc_engine.py (291)
 5. colossalai/inference/executor/rpc_worker.py (300)
 6. colossalai/inference/kv_cache/__init__.py (4)
 7. colossalai/inference/kv_cache/kvcache_manager.py (77)
 8. colossalai/inference/logit_processors.py (9)
 9. colossalai/inference/modeling/policy/nopadding_baichuan.py (10)
10. colossalai/inference/modeling/policy/nopadding_llama.py (10)
11. colossalai/inference/sampler.py (49)
12. colossalai/inference/utils.py (11)
13. requirements/requirements-test.txt (1)
14. requirements/requirements.txt (1)
15. tests/test_infer/test_rpc_engine.py (105)

115
colossalai/inference/config.py

@@ -2,11 +2,11 @@
Our config contains various options for inference optimization; it is a unified API that wraps all the configurations for inference.
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union
import torch
import torch.distributed as dist
from transformers.generation import GenerationConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
@@ -30,8 +30,25 @@ _DEFAULT_PROMPT_TEMPLATES = {
}
class RPC_PARAM(ABC):
"""
NOTE(lry89757) We use rpyc to transport parameters between the client and the server.
Rpyc only supports plain-old-data (POD) Python types as parameters, so we need a way to transport data such as tensors or more sophisticated classes.
Drawing on the logic of `__getstate__`/`__setstate__`, classes that will later be passed as RPC parameters inherit this base class and override `to_rpc_param` and `from_rpc_param`: the client invokes `to_rpc_param` to serialize the parameters, and the server recovers them with `from_rpc_param`.
"""
@abstractmethod
def to_rpc_param(self):
raise NotImplementedError
@staticmethod
@abstractmethod
def from_rpc_param():
raise NotImplementedError
@dataclass
class InputMetaData:
class InputMetaData(RPC_PARAM):
"""The input info for a single step
Args:
@@ -48,6 +65,7 @@ class InputMetaData:
dtype (torch.dtype, optional): The computation type of tensor, Defaults to torch.float32.
use_spec_dec (bool): Indicate whether to use speculative decoding.
num_tokens_to_verify (int): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True.
batch_token_ids (List[List[int]], optional): input_token_ids + output_token_ids of the current batch. Only used for `repetition_penalty` and `no_repeat_ngram_size` during sampling.
"""
block_tables: torch.Tensor = None
@@ -63,6 +81,54 @@ class InputMetaData:
dtype: torch.dtype = torch.float32
use_spec_dec: bool = False
num_tokens_to_verify: int = 0
batch_token_ids: Optional[
List[List[int]]
] = None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
def to_rpc_param(self) -> Dict[str, Any]:
return {
"block_tables": self.block_tables.tolist(),
"sequence_lengths": self.sequence_lengths.tolist(),
"batch_size": self.batch_size,
"is_prompts": self.is_prompts,
"use_cuda_kernel": self.use_cuda_kernel,
"use_cuda_graph": self.use_cuda_graph,
"kv_seq_len": self.kv_seq_len,
"head_dim": self.head_dim,
"high_precision": self.high_precision,
"dtype": str(self.dtype).split(".")[-1],
"use_spec_dec": self.use_spec_dec,
"num_tokens_to_verify": self.num_tokens_to_verify,
"batch_token_ids": self.batch_token_ids,
}
@staticmethod
def from_rpc_param(rpc_dict: Dict[str, Any]) -> "InputMetaData":
"""
We intentionally avoid the `dict.get` method so that a missing or misnamed RPC param raises an error instead of being silently defaulted.
"""
from colossalai.accelerator import get_accelerator
dtype = getattr(torch, rpc_dict["dtype"])
return InputMetaData(
block_tables=torch.tensor(
rpc_dict["block_tables"], dtype=torch.int, device=get_accelerator().get_current_device()
),
sequence_lengths=torch.tensor(
rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device()
),
batch_size=rpc_dict["batch_size"],
is_prompts=rpc_dict["is_prompts"],
use_cuda_kernel=rpc_dict["use_cuda_kernel"],
use_cuda_graph=rpc_dict["use_cuda_graph"],
kv_seq_len=rpc_dict["kv_seq_len"],
head_dim=rpc_dict["head_dim"],
high_precision=rpc_dict["high_precision"],
dtype=dtype,
use_spec_dec=rpc_dict["use_spec_dec"],
num_tokens_to_verify=rpc_dict["num_tokens_to_verify"],
batch_token_ids=rpc_dict["batch_token_ids"],
)
def __repr__(self) -> str:
return (
@@ -80,7 +146,7 @@ class InputMetaData:
@dataclass
class InferenceConfig:
class InferenceConfig(RPC_PARAM):
"""The inference configuration.
Args:
@@ -193,10 +259,6 @@ class InferenceConfig:
if self.dtype == torch.float32:
self.high_precision = False
# check distributed
assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
self.tp_size * self.pp_size == dist.get_world_size()
), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
# check prompt template
if self.prompt_template is None:
return
@@ -226,6 +288,43 @@ class InferenceConfig:
return GenerationConfig.from_dict(meta_config)
def to_rpc_param(self) -> dict:
kwargs = {
"dtype": str(self.dtype).split(".")[-1],
"max_n_spec_tokens": self.max_n_spec_tokens,
"max_batch_size": self.max_batch_size,
"max_input_len": self.max_input_len,
"max_output_len": self.max_output_len,
"tp_size": self.tp_size,
"pp_size": self.pp_size,
"pad_input": self.pad_input,
"early_stopping": self.early_stopping,
"do_sample": self.do_sample,
"beam_width": self.beam_width,
"kv_cache_dtype": str(self.kv_cache_dtype).split(".")[-1],
}
return kwargs
@staticmethod
def from_rpc_param(rpc_dict: dict) -> "InferenceConfig":
"""
We intentionally avoid the `dict.get` method so that a missing or misnamed RPC param raises an error instead of being silently defaulted.
"""
return InferenceConfig(
dtype=getattr(torch, rpc_dict["dtype"]),
max_n_spec_tokens=rpc_dict["max_n_spec_tokens"],
max_batch_size=rpc_dict["max_batch_size"],
max_input_len=rpc_dict["max_input_len"],
max_output_len=rpc_dict["max_output_len"],
tp_size=rpc_dict["tp_size"],
pp_size=rpc_dict["pp_size"],
pad_input=rpc_dict["pad_input"],
early_stopping=rpc_dict["early_stopping"],
do_sample=rpc_dict["do_sample"],
beam_width=rpc_dict["beam_width"],
kv_cache_dtype=getattr(torch, rpc_dict["kv_cache_dtype"], None),
)
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
# Get the list of attributes of this dataclass.
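To make the `RPC_PARAM` contract above concrete, here is a minimal round trip for `InputMetaData`. This is an illustrative sketch, not part of the diff: it assumes `colossalai` is importable and an accelerator is available, since `from_rpc_param` re-creates the tensors on the current device; `InferenceConfig` follows the same pattern.

# Minimal round trip for the RPC_PARAM pattern (illustrative sketch).
import torch

from colossalai.inference.config import InputMetaData

# Client side: flatten tensors into plain Python lists so rpyc can ship them.
meta = InputMetaData(
    block_tables=torch.zeros((2, 4), dtype=torch.int),
    sequence_lengths=torch.tensor([5, 3], dtype=torch.int),
    batch_size=2,
    is_prompts=True,
    dtype=torch.float16,
)
payload = meta.to_rpc_param()  # a dict containing only POD values

# Server side: rebuild the object; tensors are materialized on the local device.
meta_again = InputMetaData.from_rpc_param(payload)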

17
colossalai/inference/core/engine.py

@@ -21,6 +21,7 @@ from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.inference.utils import get_model_size, has_index_file
@@ -424,7 +425,7 @@ class InferenceEngine:
# 2. Prefill main model (Verifier) - fill past kv cache for main model
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
# append new inputs to the batch, temporarily
batch.append_batch_tokens(next_tokens)
self.request_handler.allocate_batch_spec_dec(batch, 1)
@@ -472,7 +473,7 @@ class InferenceEngine:
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
# 5. Compare and process the results
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
@@ -689,6 +690,13 @@ class InferenceEngine:
(n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
)
batch_token_ids = None
config_dict = self.generation_config.to_dict()
# process repetition_penalty, no_repeat_ngram_size
for type in ["repetition_penalty", "no_repeat_ngram_size"]:
if type in config_dict and config_dict[type] is not None:
batch_token_ids = batch.batch_token_ids
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
use_cuda_graph = False
if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
@@ -708,6 +716,7 @@ class InferenceEngine:
dtype=batch.dtype,
use_spec_dec=batch.use_spec_dec,
num_tokens_to_verify=batch.num_tokens_to_verify,
batch_token_ids=batch_token_ids,
)
return input_ids, output_tensor, input_meta_data
@@ -738,7 +747,9 @@ class InferenceEngine:
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
next_tokens = search_tokens(
self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
)
self.request_handler.append_next_tokens(next_tokens)
finished_sequences = self.request_handler.update()

95
colossalai/inference/core/request_handler.py

@@ -7,10 +7,11 @@ from transformers.generation import GenerationConfig
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.logging import get_dist_logger
logger = get_dist_logger(__name__)
__all__ = ["RunningList", "RequestHandler"]
@@ -295,17 +296,6 @@ class RequestHandler:
return None
def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig):
if generation_config.num_beams == 1:
if generation_config.do_sample:
sample_tokens = multinomial_sample(generation_config, probs)
else:
sample_tokens = greedy_sample(generation_config, logprobs)
else:
sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_bb.is_empty)
return sample_tokens
def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig):
if (
sequence.output_token_id[-1] == generation_config.eos_token_id
@@ -328,33 +318,6 @@ class RequestHandler:
def total_requests_in_batch_bucket(self) -> int:
return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
def search_tokens(self, generation_config: GenerationConfig, logits, cur_batch: BatchBucket):
"""
Sample tokens for finished requests.
"""
# NOTE: need to decide the granularity to process logits (sequence or batch)
config_dict = generation_config.to_dict()
# process repetition_penalty, no_repeat_ngram_size
for type in ["repetition_penalty", "no_repeat_ngram_size"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type], cur_batch)
# do logit processor
if generation_config.do_sample:
# process temperature, top_k, top_p
for type in ["temperature", "top_k", "top_p"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type])
# calculate probs
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# sample the next tokens
sample_tokens = self._sample(probs, logprobs, generation_config)
return sample_tokens
def append_next_tokens(self, sample_tokens: torch.Tensor):
assert sample_tokens.dim() == 1
n_elements = sample_tokens.size(0)
@@ -386,3 +349,53 @@ class RequestHandler:
self.done_list.extend(finished_seqs)
return finished_seqs
class RPCRequestHandler(RequestHandler):
"""
RPC Version of request handler
"""
def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
self.inference_config = inference_config
self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
self.waiting_list: List[List] = [[], [], []]
self.done_list: List[Sequence] = []
self.dtype = inference_config.dtype
self.max_batch_size = inference_config.max_batch_size
# initialize cache
self._init_cache(model_config)
# initialize batch
torch.cuda.current_device()
kv_max_split_num = (
inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
) // inference_config.block_size
head_dim = model_config.hidden_size // model_config.num_attention_heads
# TODO: In the continuous batching scenario, the batch size may exceed max_batch_size,
# which may cause bugs; this issue should be fixed later.
self.running_bb = BatchBucket(
num_heads=model_config.num_attention_heads // inference_config.tp_size,
head_dim=head_dim,
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
block_size=inference_config.block_size,
kv_max_split_num=kv_max_split_num,
fd_interm_tensor=None,
dtype=self.dtype,
)
self.prefill_bb = BatchBucket(
num_heads=model_config.num_attention_heads // inference_config.tp_size,
head_dim=head_dim,
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
block_size=inference_config.block_size,
kv_max_split_num=kv_max_split_num,
fd_interm_tensor=None,
dtype=self.dtype,
)
def _init_cache(self, model_config):
self.cache_manager = RPCKVCacheManager(self.inference_config, model_config)
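For orientation, this is roughly how the RPC engine below drives this handler: the scheduler lives on the client and only computes the physical cache shapes, which are later sent to the workers for allocation. A hedged sketch; the model name is just an example, a CUDA device is assumed, and constructor defaults may differ.

# Client-side scheduling with logical-only KV cache bookkeeping (sketch).
from transformers import AutoConfig

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.request_handler import RPCRequestHandler

model_config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # example model config
inference_config = InferenceConfig(max_batch_size=8, max_input_len=256, max_output_len=128)

handler = RPCRequestHandler(inference_config, model_config)

# The client never allocates KV tensors itself; it only derives the shapes the
# workers should allocate (see RPCKVCacheManager.get_physical_cache_shape below).
kalloc_shape, valloc_shape = handler.cache_manager.get_physical_cache_shape()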

291
colossalai/inference/core/rpc_engine.py

@@ -0,0 +1,291 @@
import asyncio
from itertools import count
from time import sleep
from typing import List, Tuple, Union
import rpyc
import torch
import torch.nn as nn
from rpyc.utils.server import ThreadedServer
from torch import multiprocessing as mp
from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.configuration_utils import PretrainedConfig
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.executor.rpc_worker import rpcWorkerService
from colossalai.inference.utils import find_available_ports
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy
from .engine import InferenceEngine
from .request_handler import RPCRequestHandler
__all__ = ["RPCInferenceEngine"]
def run_server(host, port, event: mp.Event = None):
server = ThreadedServer(
rpcWorkerService, port=port, protocol_config={"allow_public_attrs": True, "allow_all_attrs": True}
)
if event:
event.set()
server.start()
class RPCInferenceEngine(InferenceEngine):
"""
InferenceEngine which manages the inference process.
NOTE: This `RPCInferenceEngine` is designed for multi-card/online serving.
The original `InferenceEngine` is designed for single-card, offline use, though it also supports multi-card offline inference.
Args:
model_or_path (nn.Module or str): Path to the model or an nn.Module. Currently the `nn.Module` format is not supported.
tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Path of the tokenizer to use.
inference_config (Optional[InferenceConfig], optional): Stores the configuration information related to inference.
verbose (bool): Determines whether or not to log the generation process.
model_policy ("Policy"): The policy used to shard the model with ShardFormer. It will be determined by the model type if not provided.
"""
def __init__(
self,
model_or_path: Union[nn.Module, str],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: InferenceConfig,
verbose: bool = False,
model_policy: Policy = None,
) -> None:
"""
If you pass a real model loaded by transformers, the init will take quite a long time.
Currently we don't support passing a model (nn.Module) as the param.
"""
torch.multiprocessing.set_start_method("spawn", force=True)
self.inference_config = inference_config
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.verbose = verbose
self.logger = get_dist_logger(__name__)
try:
if isinstance(model_or_path, str):
self.model_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
elif isinstance(model_or_path, nn.Module):
self.logger.error(
f"An exception occurred during loading model Config: For {__class__.__name__}, we don't support param like nn.Module currently\n"
)
# self.model_config = model_or_path.config
else:
self.logger.error(
f"An exception occurred during loading model Config: Please pass right param for {__class__.__name__}\n"
)
except Exception as e:
self.logger.error(
f"An exception occurred during loading model Config: {e}, The path should be transformers-like\n"
)
self.generation_config = inference_config.to_generation_config(self.model_config)
self.tp_size = inference_config.tp_size
self.events = [mp.Event() for _ in range(self.tp_size)]
# This operation will init the dist env and models
self.workers: List[rpcWorkerService] = []
self.init_workers()
asyncio.run(self.init_model(model_or_path, model_policy))
# init the scheduler and logic block manager
self.request_handler = self.init_scheduler(self.inference_config, self.model_config)
# init the physical cache
alloc_shape = self.request_handler.cache_manager.get_physical_cache_shape()
self.init_device_cache(alloc_shape)
self.use_cuda_graph = self.inference_config.use_cuda_graph
self.high_precision = inference_config.high_precision
self.dtype = inference_config.dtype
# Model and related attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = False
self.drafter_model = None
self.drafter = None
self.use_glide = False
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self.counter = count()
self._verify_args()
self.logger.info("engine init over ")
def _verify_args(self) -> None:
"""Verify the input args"""
if not isinstance(self.inference_config, InferenceConfig):
raise TypeError("Invalid type of inference config provided.")
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
)
def init_workers(self):
rpc_ports = find_available_ports(self.tp_size)
self.worker_processes = []
# mp.set_start_method('spawn')
for event, rpc_port in zip(self.events, rpc_ports):
p = mp.Process(target=run_server, args=("localhost", rpc_port, event))
p.start()
self.worker_processes.append(p)
self.logger.info(f"Starting RPC Worker on localhost:{rpc_port}...")
# Wait for all servers to start
for event in self.events:
event.wait()
event.clear()
sleep(0.05)
self.logger.info(f"init rpc server done.")
for rpc_port in rpc_ports:
try:
conn = rpyc.connect(
"localhost",
rpc_port,
config={"allow_pickle": True, "allow_public_attrs": True, "allow_all_attrs": True},
)
self.workers.append(conn.root)
except Exception as e:
raise Exception(f"Failed to connect to the RPC worker at localhost:{rpc_port}") from e
self.logger.info(f"Build RPC Connection Success! Begin to load model...")
asyncio.run(self.init_worker_env())
self.logger.info(f"init dist env over")
async def async_parallel_wrapper(self, f, *args, **kwargs):
async_res = rpyc.async_(f)(*args, **kwargs)
await asyncio.to_thread(async_res.wait)
assert async_res.ready
return async_res.value
async def init_worker_env(self):
assert len(self.workers) == self.tp_size, "init workers first"
dist_group_port = find_available_ports(1)[0]
init_tasks = [
self.async_parallel_wrapper(
worker.init_dist_env, rank, self.inference_config.tp_size, "127.0.0.1", dist_group_port
)
for rank, worker in enumerate(self.workers)
]
await asyncio.gather(*init_tasks)
async def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
assert len(self.workers) == self.tp_size, "init workers first"
inference_config_param = self.inference_config.to_rpc_param()
model_path = model_or_path
model_policy_param = model_policy.to_rpc_param() if model_policy else None
init_tasks = [
self.async_parallel_wrapper(worker.init_model, inference_config_param, model_path, model_policy_param)
for rank, worker in enumerate(self.workers)
]
await asyncio.gather(*init_tasks)
def init_scheduler(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> RPCRequestHandler:
return RPCRequestHandler(inference_config, model_config)
async def _init_device_cache(self, alloc_shape: Tuple[int, int, int, int]):
assert len(self.workers) == self.tp_size, "init workers first"
init_tasks = [self.async_parallel_wrapper(worker.init_cache, alloc_shape) for worker in self.workers]
await asyncio.gather(*init_tasks)
def init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]):
asyncio.run(self._init_device_cache(alloc_shape))
def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]:
input_ids = batch.get_1D_inputs()
sequence_lengths = batch.get_sequence_lengths()
if batch.is_prompts:
n_tokens = sequence_lengths.sum().item()
else:
n_tokens = batch.current_batch_size
if batch.use_spec_dec:
n_tokens = batch.num_tokens_to_verify + 1
assert n_tokens == input_ids.size(0)
n_tokens = n_tokens * batch.current_batch_size
batch_token_ids = None
config_dict = self.generation_config.to_dict()
# process repetition_penalty, no_repeat_ngram_size
for type in ["repetition_penalty", "no_repeat_ngram_size"]:
if type in config_dict and config_dict[type] is not None:
batch_token_ids = batch.batch_token_ids
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
use_cuda_graph = False
if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
use_cuda_graph = True
input_meta_data = InputMetaData(
block_tables=batch.get_block_table_tensor(),
sequence_lengths=sequence_lengths,
fd_inter_tensor=None,
batch_size=batch.current_batch_size,
is_prompts=batch.is_prompts,
use_cuda_kernel=self.inference_config.use_cuda_kernel,
use_cuda_graph=use_cuda_graph,
high_precision=self.high_precision,
kv_seq_len=sequence_lengths.max().item(),
head_dim=batch.head_dim,
dtype=batch.dtype,
use_spec_dec=batch.use_spec_dec,
num_tokens_to_verify=batch.num_tokens_to_verify,
batch_token_ids=batch_token_ids,
)
return input_ids.tolist(), input_meta_data
async def step_(self, input_token_ids, input_meta_data: InputMetaData):
assert len(self.workers) == self.tp_size, "init workers first"
init_tasks = [
self.async_parallel_wrapper(worker.execute_model_forward, input_token_ids, input_meta_data.to_rpc_param())
for worker in self.workers
]
ret = await asyncio.gather(*init_tasks)
return ret[0]
def step(self) -> List[str]:
batch = self.request_handler.schedule()
input_token_ids, input_meta_data = self.prepare_input(batch)
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data))
# update the request_handler
next_tokens = torch.tensor(next_tokens, dtype=torch.int)
self.request_handler.append_next_tokens(next_tokens)
finished_sequences = self.request_handler.update()
return finished_sequences
def kill_workers(self):
"""
NOTE: there is currently no good way to invoke self.kill_workers implicitly, so it is called explicitly from `__del__`.
"""
assert len(self.workers) != 0
for proc in self.worker_processes:
proc.kill()
proc.join()
self.logger.info(f"worker killed, serving end")
def __del__(self):
self.kill_workers()
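A client-side usage sketch of the new engine, modeled on the test added at the bottom of this commit. The model path, prompt, and sizes are placeholders, and the call sequence relies on the `add_request`/`generate` API inherited from `InferenceEngine`.

# Sketch: multi-card serving through the RPC engine (placeholders, not a drop-in script).
from transformers import AutoTokenizer, GenerationConfig

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.rpc_engine import RPCInferenceEngine

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
inference_config = InferenceConfig(max_output_len=64, dtype="fp32", use_cuda_kernel=True, tp_size=2)

# Spawns tp_size rpyc worker processes, shards the model onto them, and allocates the KV cache remotely.
engine = RPCInferenceEngine("meta-llama/Llama-2-7b-hf", tokenizer, inference_config, verbose=True)

engine.add_request(prompts=["An example prompt"])
outputs = engine.generate(generation_config=GenerationConfig(max_new_tokens=64, do_sample=False))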

300
colossalai/inference/executor/rpc_worker.py

@@ -0,0 +1,300 @@
import os
from typing import List, Tuple, Union
import rpyc
import torch
import torch.distributed as dist
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.policy import (
NoPaddingBaichuanModelInferPolicy,
NoPaddingLlamaModelInferPolicy,
model_policy_map,
)
from colossalai.inference.sampler import search_tokens
from colossalai.inference.utils import get_model_size, has_index_file
from colossalai.interface import ModelWrapper
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy
PP_AXIS, TP_AXIS = 0, 1
_SUPPORTED_MODELS = {
"LlamaForCausalLM": LlamaForCausalLM,
"BaichuanForCausalLM": AutoModelForCausalLM,
}
_SUPPORTED_MODEL_POLICIES = {
"NoPaddingLlamaModelInferPolicy": NoPaddingLlamaModelInferPolicy,
"NoPaddingBaichuanModelInferPolicy": NoPaddingBaichuanModelInferPolicy,
}
logger = get_dist_logger(__name__)
class rpcWorkerService(rpyc.Service):
"""
Executes the computation tasks and manages its own KV cache.
Methods prefixed with `exposed_` can be invoked by the client.
"""
def exposed_init_dist_env(self, rank, world_size, master_address, master_port):
logger.info(f"init process group for rank {rank}")
colossalai.launch(rank=rank, world_size=world_size, port=master_port, host=master_address)
logger.info(f"init process group done for rank {rank}")
def exposed_init_model(
self, inference_config_param: dict, model_or_path: Union[nn.Module, str], model_policy_param: str = None
):
assert dist.is_initialized(), "invoke init_dist_env first please!"
self.inference_config = InferenceConfig.from_rpc_param(inference_config_param)
model_policy = _SUPPORTED_MODEL_POLICIES[model_policy_param]() if model_policy_param else None
self.dtype = self.inference_config.dtype
self.verbose = True
self._init_model(model_or_path, model_policy)
self._init_fd_tensor()
self._init_output_tensor()
logger.info(f"init model done for rank {dist.get_rank()}")
def exposed_init_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]):
"""Initialize the physical cache on the device.
For each layer of the model, we allocate two tensors for key and value respectively,
with shape of [num_blocks, num_kv_heads, block_size, head_size]
"""
kalloc_shape, valloc_shape = alloc_shape
num_layers = self.model_config.num_hidden_layers
self.k_cache: List[torch.Tensor] = []
self.v_cache: List[torch.Tensor] = []
for _ in range(num_layers):
self.k_cache.append(
torch.zeros(
kalloc_shape,
dtype=self.inference_config.kv_cache_dtype,
device=get_accelerator().get_current_device(),
)
)
self.v_cache.append(
torch.zeros(
valloc_shape,
dtype=self.inference_config.kv_cache_dtype,
device=get_accelerator().get_current_device(),
)
)
logger.info("physical cache init over")
def exposed_execute_model_forward(self, input_token_ids_param: List[int], input_meta_data_param: dict):
# prepare the data for model forward
input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
input_meta_data.fd_inter_tensor = self.fd_inter_tensor
if input_meta_data.is_prompts:
n_tokens = input_meta_data.sequence_lengths.sum().item()
else:
n_tokens = input_meta_data.batch_size
input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device)
# execute the model
logits = self.model(
input_token_ids,
self.output_tensor[:n_tokens],
input_meta_data,
self.k_cache,
self.v_cache,
)
# sampler
if self.inference_config.pad_input:
logits = logits[:, -1, :]
next_tokens = search_tokens(
self.inference_config.to_generation_config(self.model_config),
logits,
input_meta_data.is_prompts,
input_meta_data.batch_token_ids,
)
# return the tokens generated to scheduler
return next_tokens.tolist()
def _init_output_tensor(self):
alloc_shape = (
self.inference_config.max_batch_size
* (self.inference_config.max_input_len + self.inference_config.max_output_len),
self.model_config.hidden_size // self.inference_config.tp_size,
)
self.output_tensor = torch.zeros(alloc_shape, dtype=self.dtype, device=self.device)
def _init_fd_tensor(self):
fd_inter_tensor = FDIntermTensors()
if fd_inter_tensor._tensors_initialized:
fd_inter_tensor._reset()
# For Spec-Dec, process the speculated tokens plus the token in the last step for each seq
max_n_tokens = self.inference_config.max_batch_size
max_n_tokens *= self.inference_config.max_n_spec_tokens + 1
inference_config = self.inference_config
kv_max_split_num = (
inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
) // inference_config.block_size
head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
fd_inter_tensor.initialize(
max_batch_size=max_n_tokens,
num_attn_heads=self.model_config.num_attention_heads // self.inference_config.tp_size,
kv_max_split_num=kv_max_split_num,
head_dim=head_dim,
dtype=self.dtype,
device=get_accelerator().get_current_device(),
)
self.fd_inter_tensor = fd_inter_tensor
def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
"""
Shard the model and/or load its weights.
Shard model: when tp_size > 1, the model is sharded according to the given model_policy.
Load weights: if a local model path is passed, the weights are loaded via checkpoint_io; if it is a remote transformers repo id, the `AutoModel.from_pretrained` API of the transformers library is used.
Args:
model_or_path (Union[nn.Module, str]): path to the checkpoint, or a model in transformers format.
model_policy (Policy): the policy used to replace (shard) the model.
"""
if isinstance(model_or_path, str):
is_local = os.path.isdir(model_or_path)
try:
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
arch = getattr(hf_config, "architectures")[0]
if is_local:
model = _SUPPORTED_MODELS[arch](hf_config)
else:
# load the real checkpoint
model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
except Exception as e:
logger.error(
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
)
else:
model = model_or_path
self.model_config = model.config
torch.cuda.empty_cache()
init_gpu_memory = torch.cuda.mem_get_info()[0]
self.device = get_accelerator().get_current_device()
torch.cuda.set_device(self.device)
if self.verbose:
logger.info(f"the device is {self.device}")
model = model.to(dtype=self.dtype, non_blocking=False).eval()
if self.verbose:
logger.info(
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
)
if model_policy is None:
if self.inference_config.pad_input:
model_type = "padding_" + self.model_config.model_type
else:
model_type = "nopadding_" + self.model_config.model_type
model_policy = model_policy_map[model_type]()
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
self.model = self._shardformer(
model,
model_policy,
None,
tp_group=tp_group,
)
self.model = ModelWrapper(model).to(device=get_accelerator().get_current_device())
if self.verbose:
logger.info(
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
)
if isinstance(model_or_path, str) and is_local:
from colossalai.inference.core.plugin import InferCheckpoint_io
cpt_io = InferCheckpoint_io()
if_has_index_file, model_index_file = has_index_file(model_or_path)
assert if_has_index_file, "the model path is invalid"
cpt_io.load_model(self.model, model_index_file)
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
logger.info(
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
)
def _shardformer(
self,
model: nn.Module,
model_policy: Policy,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
) -> nn.Module:
"""
Initialize ShardConfig and replace the model with shardformer.
Args:
model (nn.Module): The model to be sharded.
model_policy (Policy): The policy to shardformer model which is determined by the model type.
stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
Returns:
nn.Module: The model optimized by Shardformer.
"""
shardconfig = ShardConfig(
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
enable_tensor_parallelism=(self.inference_config.tp_size > 1),
enable_fused_normalization=False,
enable_all_optimization=False,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model
def exposed_compute_only_for_test(self):
dist_rank = dist.get_rank()
# Dummy data for each worker
data = torch.tensor([dist_rank], dtype=torch.float).cuda(dist_rank)
dist.barrier()
# Perform distributed all_reduce
dist.all_reduce(data, op=dist.ReduceOp.SUM)
dist.barrier()
logger.info(f"Worker rank {dist_rank}: Sum after all_reduce: {data.item()}")
return data.item()
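The worker relies on two rpyc conventions: methods prefixed with `exposed_` become remotely callable, and `rpyc.async_` turns a remote call into an `AsyncResult` that the engine can await and gather across workers. Below is a stripped-down, hypothetical illustration of that pattern; the service, port, and method names are made up and it is not ColossalAI code.

# Hypothetical mini-example of the rpyc pattern used by rpcWorkerService and RPCInferenceEngine.
import asyncio
import threading
import time

import rpyc
from rpyc.utils.server import ThreadedServer


class EchoWorkerService(rpyc.Service):
    # Methods prefixed with `exposed_` are callable from the client, as in rpcWorkerService.
    def exposed_add_one(self, value: int) -> int:
        return value + 1


async def call_async(remote_fn, *args):
    # Same idea as RPCInferenceEngine.async_parallel_wrapper: wrap the rpyc AsyncResult
    # so several workers can be driven concurrently with asyncio.gather.
    async_res = rpyc.async_(remote_fn)(*args)
    await asyncio.to_thread(async_res.wait)
    return async_res.value


if __name__ == "__main__":
    server = ThreadedServer(EchoWorkerService, port=18861, protocol_config={"allow_public_attrs": True})
    threading.Thread(target=server.start, daemon=True).start()
    time.sleep(0.5)  # give the server a moment to start listening

    conn = rpyc.connect("localhost", 18861, config={"allow_public_attrs": True})
    print(asyncio.run(call_async(conn.root.add_one, 41)))  # -> 42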

4
colossalai/inference/kv_cache/__init__.py

@@ -1,4 +1,4 @@
from .block_cache import CacheBlock
from .kvcache_manager import KVCacheManager
from .kvcache_manager import KVCacheManager, RPCKVCacheManager
__all__ = ["CacheBlock", "KVCacheManager"]
__all__ = ["CacheBlock", "KVCacheManager", "RPCKVCacheManager"]

77
colossalai/inference/kv_cache/kvcache_manager.py

@@ -497,3 +497,80 @@ class KVCacheManager:
k_cache.append(torch.zeros(kalloc_shape, dtype=self.kv_cache_dtype, device=self.device))
v_cache.append(torch.zeros(valloc_shape, dtype=self.kv_cache_dtype, device=self.device))
return k_cache, v_cache
class RPCKVCacheManager(KVCacheManager):
def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
self.logger = get_dist_logger(__name__)
self.device = get_current_device()
self.config = config
# Parallel settings
self.tp_size = config.tp_size
# Model settings
self.dtype = config.dtype
self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
self.num_layers = model_config.num_hidden_layers
self.head_num = model_config.num_attention_heads
self.head_size = model_config.hidden_size // self.head_num
if hasattr(model_config, "num_key_value_heads"):
self.kv_head_num = model_config.num_key_value_heads
else:
self.kv_head_num = self.head_num
if config.kv_cache_dtype is None:
self.kv_cache_dtype = config.dtype
else:
self.kv_cache_dtype = config.kv_cache_dtype
assert (
self.kv_head_num % self.tp_size == 0
), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}"
self.kv_head_num //= self.tp_size
self.beam_width = config.beam_width
self.max_batch_size = config.max_batch_size
self.max_input_length = config.max_input_len
self.max_output_length = config.max_output_len
# Cache block settings
self.block_size = config.block_size
# NOTE: `num_blocks` is not configured directly; it is derived from the maximum input/output length and the maximum batch size
self.max_blocks_per_sequence = (
self.max_input_length + self.max_output_length + self.block_size - 1
) // self.block_size
self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
# Logical cache blocks allocation
self._available_blocks = self.num_blocks
self._cache_blocks = tuple(self._init_logical_caches())
# block availability state: 0 -> allocated, 1 -> free
self._block_states = torch.ones((self.num_blocks,), dtype=torch.bool)
self._block_states_cum = torch.zeros(size=(self.num_blocks + 1,), dtype=torch.int64)
self._block_finder = torch.zeros((self.num_blocks,), dtype=torch.int64)
def get_physical_cache_shape(self) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
# Physical cache allocation
if self.config.use_cuda_kernel:
x = 16 // torch.tensor([], dtype=self.config.dtype).element_size()
kalloc_shape = (self.num_blocks, self.kv_head_num, self.head_size // x, self.block_size, x)
valloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
self.logger.info(
f"Allocating K cache with shape: {kalloc_shape}, V cache with shape: {valloc_shape} consisting of {self.num_blocks} blocks."
)
else:
alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
kalloc_shape = alloc_shape
valloc_shape = alloc_shape
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
return kalloc_shape, valloc_shape
def get_kv_cache(self):
"""Get k_cache and v_cache"""
return NotImplementedError
def _init_logical_caches(self):
"""Initialize the logical cache blocks."""
blocks = []
for i in range(self.num_blocks):
cache_block = CacheBlock(i, self.block_size, self.elem_size_in_bytes, k_ptrs=None, v_ptrs=None)
blocks.append(cache_block)
return blocks
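To make `get_physical_cache_shape` concrete, a worked example with hypothetical numbers (512 blocks, 16 KV heads after TP sharding, head size 128, block size 16, fp16). The `x = 16 // element_size` factor is the packing used by the CUDA-kernel K-cache layout.

# Worked example of the shape computation above (hypothetical numbers).
import torch

num_blocks = 512        # max_blocks_per_sequence * max_batch_size * beam_width
kv_head_num = 16        # e.g. 32 KV heads sharded across tp_size = 2
head_size = 128
block_size = 16
dtype = torch.float16

x = 16 // torch.tensor([], dtype=dtype).element_size()  # 16 // 2 = 8
kalloc_shape = (num_blocks, kv_head_num, head_size // x, block_size, x)
valloc_shape = (num_blocks, kv_head_num, block_size, head_size)

print(kalloc_shape)  # (512, 16, 16, 16, 8) when use_cuda_kernel=True
print(valloc_shape)  # (512, 16, 16, 128)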

9
colossalai/inference/logit_processors.py

@@ -1,10 +1,9 @@
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
from typing import List
import torch
import torch.nn.functional as F
from colossalai.inference.batch_bucket import BatchBucket
_LOGIT_PROCESSOR_MAP = {}
@@ -22,7 +21,7 @@ def register_logit_processor(process_type):
@register_logit_processor("no_repeat_ngram_size")
def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBucket):
def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: List[List[int]]):
"""
enforces no repetition of n-grams to avoid repetitions of word sequences.
"""
@@ -31,7 +30,6 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBucket):
raise ValueError(f"'temperature={ngram_size}' should be a strictly positive integer.")
if ngram_size != 0:
batch_token_ids = batch.batch_token_ids
batch_size = len(batch_token_ids)
for batch_id in range(batch_size):
@@ -55,7 +53,7 @@
@register_logit_processor("repetition_penalty")
def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket):
def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: List[List[int]]):
"""
apply the penalty to the tokens present in the prompt.
"""
@@ -67,7 +65,6 @@ def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket):
# TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
if penalty != 1.0:
batch_token_ids = batch.batch_token_ids
for batch_id in range(len(batch_token_ids)):
current_logit = logits[batch_id]
current_token = torch.tensor(batch_token_ids[batch_id], dtype=torch.long, device=logits.device)
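The point of this signature change is that the processors now only need plain token-id lists, which can travel over RPC, rather than a `BatchBucket`. A small illustrative call, mirroring how `search_tokens` in `colossalai/inference/sampler.py` dispatches them; the vocabulary size and token ids are placeholders.

# Sketch: calling the refactored processors with bare token-id lists.
import torch

from colossalai.inference.logit_processors import logit_processor

batch_token_ids = [[1, 5, 5, 7], [2, 3, 3, 3]]      # prompt + generated ids per sequence
logits = torch.randn(len(batch_token_ids), 32000)   # (batch_size, vocab_size)

logits = logit_processor("repetition_penalty", logits, 1.2, batch_token_ids)
logits = logit_processor("no_repeat_ngram_size", logits, 2, batch_token_ids)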

10
colossalai/inference/modeling/policy/nopadding_baichuan.py

@@ -1,3 +1,4 @@
from colossalai.inference.config import RPC_PARAM
from colossalai.inference.modeling.layers.baichuan_tp_linear import (
BaichuanLMHeadLinear1D_Col,
BaichuanWpackLinear1D_Col,
@@ -18,7 +19,7 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription,
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy):
class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
def __init__(self) -> None:
super().__init__()
@@ -100,3 +101,10 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy):
def postprocess(self):
init_to_get_rotary(self.model.model)
return self.model
def to_rpc_param(self) -> str:
return __class__.__name__
@staticmethod
def from_rpc_param() -> "NoPaddingBaichuanModelInferPolicy":
return NoPaddingBaichuanModelInferPolicy()

10
colossalai/inference/modeling/policy/nopadding_llama.py

@@ -1,5 +1,6 @@
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
from colossalai.inference.config import RPC_PARAM
from colossalai.inference.modeling.models.nopadding_llama import (
NopadLlamaAttention,
NopadLlamaMLP,
@@ -14,7 +15,7 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription,
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
def __init__(self) -> None:
super().__init__()
@@ -102,3 +103,10 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
def postprocess(self):
init_to_get_rotary(self.model.model, self.model.config.rope_theta)
return self.model
def to_rpc_param(self) -> str:
return __class__.__name__
@staticmethod
def from_rpc_param() -> "NoPaddingLlamaModelInferPolicy":
return NoPaddingLlamaModelInferPolicy()

49
colossalai/inference/sampler.py

@@ -1,6 +1,9 @@
from typing import List, Tuple
from typing import List, Optional, Tuple
import torch
from transformers.generation import GenerationConfig
from colossalai.inference.logit_processors import logit_processor
def greedy_sample(
@@ -59,3 +62,47 @@
results.append((next_token_ids, parent_ids))
return results
def _sample(probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig, is_prompt: bool = False):
if generation_config.num_beams == 1:
if generation_config.do_sample:
sample_tokens = multinomial_sample(generation_config, probs)
else:
sample_tokens = greedy_sample(generation_config, logprobs)
else:
sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=is_prompt)
return sample_tokens
def search_tokens(
generation_config: GenerationConfig,
logits,
is_prompt: bool = False,
batch_token_ids: Optional[List[List[int]]] = None,
):
"""
Sample the next tokens for the current batch.
"""
# NOTE: need to decide the granularity to process logits (sequence or batch)
config_dict = generation_config.to_dict()
# process repetition_penalty, no_repeat_ngram_size
for type in ["repetition_penalty", "no_repeat_ngram_size"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type], batch_token_ids)
# do logit processor
if generation_config.do_sample:
# process temperature, top_k, top_p
for type in ["temperature", "top_k", "top_p"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type])
# calculate probs
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# sample the next tokens
sample_tokens = _sample(probs, logprobs, generation_config, is_prompt)
return sample_tokens
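A greedy-decoding sketch of the standalone sampler; the vocabulary size and token ids are placeholders, and `batch_token_ids` is only needed here because `repetition_penalty` is set in the generation config.

# Sketch: greedy decoding through search_tokens.
import torch
from transformers.generation import GenerationConfig

from colossalai.inference.sampler import search_tokens

generation_config = GenerationConfig(do_sample=False, repetition_penalty=1.2)
logits = torch.randn(4, 32000)      # (batch_size, vocab_size)
batch_token_ids = [[1, 2, 3]] * 4   # one id list per sequence

next_tokens = search_tokens(generation_config, logits, is_prompt=True, batch_token_ids=batch_token_ids)
print(next_tokens)  # one next-token id per sequence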

11
colossalai/inference/utils.py

@@ -9,6 +9,8 @@ from typing import Optional, Tuple
import torch
from torch import nn
from colossalai.testing import free_port
def init_to_get_rotary(self, base=10000, use_elem=False):
"""
@@ -102,3 +104,12 @@ def get_model_size(model: nn.Module):
for key, param in model.named_parameters():
total_size += param.element_size() * param.numel()
return total_size / (1024**3)
def find_available_ports(num: int):
try:
free_ports = [free_port() for _ in range(num)]
except OSError as e:
raise RuntimeError(f"Error finding available ports: {e}") from e
return free_ports

1
requirements/requirements-test.txt

@@ -19,4 +19,5 @@ datasets
pydantic
ray
peft>=0.7.1
rpyc==6.0.0
#auto-gptq now not support torch1.12

1
requirements/requirements.txt

@@ -19,3 +19,4 @@ protobuf
transformers==4.36.2
peft>=0.7.1
bitsandbytes>=0.39.0
rpyc==6.0.0

105
tests/test_infer/test_rpc_engine.py

@@ -0,0 +1,105 @@
import random
import numpy as np
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.rpc_engine import RPCInferenceEngine
from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy
from colossalai.testing import parameterize, rerun_if_address_is_in_use
def setup_seed(seed):
torch.manual_seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def check_inference_engine(tp_size, use_engine=False, prompt_template=None, do_sample=True, policy=None):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = "meta-llama/Llama-2-7b-hf" # remote mode path
inputs = [
"介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
"介绍一下武汉,",
]
output_len = 38
top_p = 0.5
top_k = 50
if use_engine:
inference_config = InferenceConfig(
max_output_len=output_len,
prompt_template=prompt_template,
dtype="fp32",
use_cuda_kernel=True,
tp_size=tp_size,
)
inference_engine = RPCInferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(
max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
)
outputs = inference_engine.generate(generation_config=generation_config)
else:
if prompt_template:
# apply prompt template
inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
model = AutoModelForCausalLM.from_pretrained(model).cuda()
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
inputs = inputs.cuda()
generation_config = GenerationConfig(
do_sample=do_sample,
dtype="fp32",
top_p=top_p,
top_k=top_k,
pad_token_id=tokenizer.pad_token_id,
max_new_tokens=output_len,
)
outputs = model.generate(inputs, generation_config=generation_config)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
return outputs
def run_engine(tp_size, **kwargs):
return check_inference_engine(tp_size=tp_size, **kwargs)
@pytest.mark.largedist
@parameterize("prompt_template", [None, "llama"])
@parameterize("do_sample", [False])
@rerun_if_address_is_in_use()
def test_tp_engine(prompt_template, do_sample):
if torch.multiprocessing.get_start_method(allow_none=True) is None:
torch.multiprocessing.set_start_method("spawn")
kwargs1 = {
"use_engine": True,
"prompt_template": prompt_template,
"do_sample": do_sample,
"policy": NoPaddingLlamaModelInferPolicy(),
}
kwargs2 = {"use_engine": False, "prompt_template": prompt_template, "do_sample": do_sample, "policy": None}
colossal_tp_1_output = run_engine(1, **kwargs1)
colossal_tp_2_output = run_engine(2, **kwargs1)
transformer_tp_1_output = run_engine(1, **kwargs2)
for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
if __name__ == "__main__":
torch.multiprocessing.set_start_method("spawn")  # "spawn" is required here; forking subprocesses will not work with this setup
test_tp_engine()