mirror of https://github.com/hpcaitech/ColossalAI

[Feat] Inference RPC Server Support (#5705)

* rpc support source
* kv cache logical/physical disaggregation
* sampler refactor
* colossalai launch built in
* unit tests
* rpyc support

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

pull/5716/head
parent de4bf3dedf
commit 18d67d0e8e
@@ -2,11 +2,11 @@
 Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
 """
 import logging
+from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import torch
-import torch.distributed as dist
 from transformers.generation import GenerationConfig

 from colossalai.inference.flash_decoding_utils import FDIntermTensors
@@ -30,8 +30,25 @@ _DEFAULT_PROMPT_TEMPLATES = {
 }


+class RPC_PARAM(ABC):
+    """
+    NOTE(lry89757) We use rpyc to transport params between the client and the server.
+    rpyc only supports plain-old-data (POD) Python types as parameters, so richer objects such as tensors or other sophisticated classes have to be converted before being sent.
+    Following the spirit of `__getstate__`/`__setstate__`, classes that are later passed as RPC params inherit this base class and override `to_rpc_param` and `from_rpc_param`: the client calls `to_rpc_param` to serialize the object, and the server rebuilds it with `from_rpc_param`.
+    """
+
+    @abstractmethod
+    def to_rpc_param(self):
+        return NotImplementedError
+
+    @staticmethod
+    @abstractmethod
+    def from_rpc_param():
+        return NotImplementedError
+
+
 @dataclass
-class InputMetaData:
+class InputMetaData(RPC_PARAM):
     """The input info for a single step

     Args:
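The contract above boils down to a serialize/deserialize pair built only from POD values. Below is a minimal, standalone sketch of the pattern; the `TensorPayload` class is illustrative and not part of this PR.

from abc import ABC, abstractmethod

import torch


class RPC_PARAM(ABC):  # repeated here only to keep the snippet standalone
    @abstractmethod
    def to_rpc_param(self):
        return NotImplementedError

    @staticmethod
    @abstractmethod
    def from_rpc_param():
        return NotImplementedError


class TensorPayload(RPC_PARAM):
    """Hypothetical payload carrying a tensor across an rpyc connection."""

    def __init__(self, data: torch.Tensor):
        self.data = data

    def to_rpc_param(self) -> dict:
        # tensors become plain Python lists plus a dtype name (POD only)
        return {"data": self.data.tolist(), "dtype": str(self.data.dtype).split(".")[-1]}

    @staticmethod
    def from_rpc_param(rpc_dict: dict) -> "TensorPayload":
        # the receiving side rebuilds the tensor from the POD values
        return TensorPayload(torch.tensor(rpc_dict["data"], dtype=getattr(torch, rpc_dict["dtype"])))


payload = TensorPayload(torch.arange(4, dtype=torch.int32))
restored = TensorPayload.from_rpc_param(payload.to_rpc_param())
assert torch.equal(payload.data, restored.data)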
@@ -48,6 +65,7 @@ class InputMetaData:
         dtype (torch.dtype, optional): The computation type of tensor, Defaults to torch.float32.
         use_spec_dec (bool): Indicate whether to use speculative decoding.
         num_tokens_to_verify (int): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True.
+        batch_token_ids (List[List[int]], optional): input_token_ids + output_token_ids of the current batch. Only used for `repetition_penalty` and `no_repeat_ngram_size` in the sampling process.
     """

     block_tables: torch.Tensor = None
@@ -63,6 +81,54 @@ class InputMetaData:
     dtype: torch.dtype = torch.float32
     use_spec_dec: bool = False
     num_tokens_to_verify: int = 0
+    batch_token_ids: Optional[
+        List[List[int]]
+    ] = None  # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
+
+    def to_rpc_param(self) -> Dict[str, any]:
+        return {
+            "block_tables": self.block_tables.tolist(),
+            "sequence_lengths": self.sequence_lengths.tolist(),
+            "batch_size": self.batch_size,
+            "is_prompts": self.is_prompts,
+            "use_cuda_kernel": self.use_cuda_kernel,
+            "use_cuda_graph": self.use_cuda_graph,
+            "kv_seq_len": self.kv_seq_len,
+            "head_dim": self.head_dim,
+            "high_precision": self.high_precision,
+            "dtype": str(self.dtype).split(".")[-1],
+            "use_spec_dec": self.use_spec_dec,
+            "num_tokens_to_verify": self.num_tokens_to_verify,
+            "batch_token_ids": self.batch_token_ids,
+        }
+
+    @staticmethod
+    def from_rpc_param(rpc_dict: Dict[str, any]) -> "InputMetaData":
+        """
+        We deliberately index the dict instead of using `dict.get`, so that a missing RPC param raises an error immediately rather than being silently replaced by a default.
+        """
+        from colossalai.accelerator import get_accelerator
+
+        dtype = getattr(torch, rpc_dict["dtype"])
+        return InputMetaData(
+            block_tables=torch.tensor(
+                rpc_dict["block_tables"], dtype=torch.int, device=get_accelerator().get_current_device()
+            ),
+            sequence_lengths=torch.tensor(
+                rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device()
+            ),
+            batch_size=rpc_dict["batch_size"],
+            is_prompts=rpc_dict["is_prompts"],
+            use_cuda_kernel=rpc_dict["use_cuda_kernel"],
+            use_cuda_graph=rpc_dict["use_cuda_graph"],
+            kv_seq_len=rpc_dict["kv_seq_len"],
+            head_dim=rpc_dict["head_dim"],
+            high_precision=rpc_dict["high_precision"],
+            dtype=dtype,
+            use_spec_dec=rpc_dict["use_spec_dec"],
+            num_tokens_to_verify=rpc_dict["num_tokens_to_verify"],
+            batch_token_ids=rpc_dict["batch_token_ids"],
+        )
+
     def __repr__(self) -> str:
         return (
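The docstring's point about avoiding `dict.get` is fail-fast deserialization; a tiny illustration with a made-up dict (not the PR's data):

rpc_dict = {"batch_size": 2}  # "kv_seq_len" forgotten by a hypothetical caller

try:
    kv_seq_len = rpc_dict["kv_seq_len"]   # direct indexing: fails loudly here
except KeyError as err:
    print(f"missing rpc param: {err}")

kv_seq_len = rpc_dict.get("kv_seq_len")   # .get(): silently None, fails later inside the forward pass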
@@ -80,7 +146,7 @@ class InputMetaData:


 @dataclass
-class InferenceConfig:
+class InferenceConfig(RPC_PARAM):
     """The inference configuration.

     Args:
@@ -193,10 +259,6 @@
         if self.dtype == torch.float32:
             self.high_precision = False

-        # check distributed
-        assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
-            self.tp_size * self.pp_size == dist.get_world_size()
-        ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
         # check prompt template
         if self.prompt_template is None:
             return
@@ -226,6 +288,43 @@ class InferenceConfig:

         return GenerationConfig.from_dict(meta_config)

+    def to_rpc_param(self) -> dict:
+        kwargs = {
+            "dtype": str(self.dtype).split(".")[-1],
+            "max_n_spec_tokens": self.max_n_spec_tokens,
+            "max_batch_size": self.max_batch_size,
+            "max_input_len": self.max_input_len,
+            "max_output_len": self.max_output_len,
+            "tp_size": self.tp_size,
+            "pp_size": self.pp_size,
+            "pad_input": self.pad_input,
+            "early_stopping": self.early_stopping,
+            "do_sample": self.do_sample,
+            "beam_width": self.beam_width,
+            "kv_cache_dtype": str(self.kv_cache_dtype).split(".")[-1],
+        }
+        return kwargs
+
+    @staticmethod
+    def from_rpc_param(rpc_dict: dict) -> "InferenceConfig":
+        """
+        We deliberately index the dict instead of using `dict.get`, so that a missing RPC param raises an error immediately rather than being silently replaced by a default.
+        """
+        return InferenceConfig(
+            dtype=getattr(torch, rpc_dict["dtype"]),
+            max_n_spec_tokens=rpc_dict["max_n_spec_tokens"],
+            max_batch_size=rpc_dict["max_batch_size"],
+            max_input_len=rpc_dict["max_input_len"],
+            max_output_len=rpc_dict["max_output_len"],
+            tp_size=rpc_dict["tp_size"],
+            pp_size=rpc_dict["pp_size"],
+            pad_input=rpc_dict["pad_input"],
+            early_stopping=rpc_dict["early_stopping"],
+            do_sample=rpc_dict["do_sample"],
+            beam_width=rpc_dict["beam_width"],
+            kv_cache_dtype=getattr(torch, rpc_dict["kv_cache_dtype"], None),
+        )
+
     @classmethod
     def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
         # Get the list of attributes of this dataclass.
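Both `to_rpc_param` methods ship torch dtypes by name because dtype objects are not rpyc-friendly. A short sketch of the round trip, including the `kv_cache_dtype=None` corner case handled by the three-argument `getattr`:

import torch

dtype = torch.float16
dtype_name = str(dtype).split(".")[-1]          # "float16" is what travels over the wire
assert getattr(torch, dtype_name) is torch.float16

# kv_cache_dtype may be None, hence the three-argument getattr on the receiving side
kv_name = str(None).split(".")[-1]              # "None"
assert getattr(torch, kv_name, None) is None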
@@ -21,6 +21,7 @@ from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.config import InferenceConfig, InputMetaData
 from colossalai.inference.graph_runner import CUDAGraphRunner
 from colossalai.inference.modeling.policy import model_policy_map
+from colossalai.inference.sampler import search_tokens
 from colossalai.inference.spec import Drafter, GlideInput
 from colossalai.inference.struct import Sequence
 from colossalai.inference.utils import get_model_size, has_index_file
@@ -424,7 +425,7 @@ class InferenceEngine:

         # 2. Prefill main model (Verifier) - fill past kv cache for main model
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
-        next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
+        next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
         # append new inputs to the batch, temporarily
         batch.append_batch_tokens(next_tokens)
         self.request_handler.allocate_batch_spec_dec(batch, 1)
@@ -472,7 +473,7 @@ class InferenceEngine:
         input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)

-        next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
+        next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)

         # 5. Compare and process the results
         diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
@@ -689,6 +690,13 @@ class InferenceEngine:
             (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
         )

+        batch_token_ids = None
+        config_dict = self.generation_config.to_dict()
+        # process repetition_penalty, no_repeat_ngram_size
+        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
+            if type in config_dict and config_dict[type] is not None:
+                batch_token_ids = batch.batch_token_ids
+
         # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
         use_cuda_graph = False
         if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
@@ -708,6 +716,7 @@ class InferenceEngine:
             dtype=batch.dtype,
             use_spec_dec=batch.use_spec_dec,
             num_tokens_to_verify=batch.num_tokens_to_verify,
+            batch_token_ids=batch_token_ids,
         )

         return input_ids, output_tensor, input_meta_data
@@ -738,7 +747,9 @@ class InferenceEngine:
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
-        next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
+        next_tokens = search_tokens(
+            self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
+        )
         self.request_handler.append_next_tokens(next_tokens)
         finished_sequences = self.request_handler.update()
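With the sampler pulled out of the request handler, the engine only collects `batch_token_ids` when the generation config actually enables a history-dependent logit processor. A hedged sketch of that gating check (the helper name is ours, not from the PR):

from transformers import GenerationConfig

def needs_batch_token_ids(generation_config: GenerationConfig) -> bool:
    # Mirrors the gating in prepare_input(): token histories are only gathered
    # when the config carries a value for one of the history-dependent processors.
    config_dict = generation_config.to_dict()
    return any(
        key in config_dict and config_dict[key] is not None
        for key in ("repetition_penalty", "no_repeat_ngram_size")
    )

print(needs_batch_token_ids(GenerationConfig(repetition_penalty=1.2)))  # True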
@@ -7,10 +7,11 @@ from transformers.generation import GenerationConfig
 from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
-from colossalai.inference.kv_cache import KVCacheManager
-from colossalai.inference.logit_processors import logit_processor
-from colossalai.inference.sampler import *
+from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
 from colossalai.inference.struct import RequestStatus, Sequence
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger(__name__)

 __all__ = ["RunningList", "RequestHandler"]
@@ -295,17 +296,6 @@ class RequestHandler:

         return None

-    def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig):
-        if generation_config.num_beams == 1:
-            if generation_config.do_sample:
-                sample_tokens = multinomial_sample(generation_config, probs)
-            else:
-                sample_tokens = greedy_sample(generation_config, logprobs)
-        else:
-            sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_bb.is_empty)
-
-        return sample_tokens
-
     def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig):
         if (
             sequence.output_token_id[-1] == generation_config.eos_token_id
@@ -328,33 +318,6 @@ class RequestHandler:
     def total_requests_in_batch_bucket(self) -> int:
         return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size

-    def search_tokens(self, generation_config: GenerationConfig, logits, cur_batch: BatchBucket):
-        """
-        Sample tokens for finished requests.
-        """
-
-        # NOTE: need to decide the granularity to process logits (sequence or batch)
-        config_dict = generation_config.to_dict()
-        # process repetition_penalty, no_repeat_ngram_size
-        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
-            if type in config_dict and config_dict[type] is not None:
-                logits = logit_processor(type, logits, config_dict[type], cur_batch)
-
-        # do logit processor
-        if generation_config.do_sample:
-            # process temperature, top_k, top_p
-            for type in ["temperature", "top_k", "top_p"]:
-                if type in config_dict and config_dict[type] is not None:
-                    logits = logit_processor(type, logits, config_dict[type])
-
-        # calculate probs
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-
-        # sample the next tokens
-        sample_tokens = self._sample(probs, logprobs, generation_config)
-        return sample_tokens
-
     def append_next_tokens(self, sample_tokens: torch.Tensor):
         assert sample_tokens.dim() == 1
         n_elements = sample_tokens.size(0)
@@ -386,3 +349,53 @@ class RequestHandler:
         self.done_list.extend(finished_seqs)

         return finished_seqs
+
+
+class RPCRequestHandler(RequestHandler):
+    """
+    RPC version of the request handler: it schedules requests and manages the logical KV cache on the client side, while the physical cache lives on the workers.
+    """
+
+    def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
+        self.inference_config = inference_config
+        self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
+        self.waiting_list: List[List] = [[], [], []]
+        self.done_list: List[Sequence] = []
+        self.dtype = inference_config.dtype
+        self.max_batch_size = inference_config.max_batch_size
+
+        # initialize cache
+        self._init_cache(model_config)
+
+        # initialize batch
+        torch.cuda.current_device()
+        kv_max_split_num = (
+            inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
+        ) // inference_config.block_size
+        head_dim = model_config.hidden_size // model_config.num_attention_heads
+
+        # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
+        # which may cause bugs and this issue should be fixed later.
+        self.running_bb = BatchBucket(
+            num_heads=model_config.num_attention_heads // inference_config.tp_size,
+            head_dim=head_dim,
+            max_batch_size=self.max_batch_size,
+            max_length=inference_config.max_input_len + inference_config.max_output_len,
+            block_size=inference_config.block_size,
+            kv_max_split_num=kv_max_split_num,
+            fd_interm_tensor=None,
+            dtype=self.dtype,
+        )
+        self.prefill_bb = BatchBucket(
+            num_heads=model_config.num_attention_heads // inference_config.tp_size,
+            head_dim=head_dim,
+            max_batch_size=self.max_batch_size,
+            max_length=inference_config.max_input_len + inference_config.max_output_len,
+            block_size=inference_config.block_size,
+            kv_max_split_num=kv_max_split_num,
+            fd_interm_tensor=None,
+            dtype=self.dtype,
+        )
+
+    def _init_cache(self, model_config):
+        self.cache_manager = RPCKVCacheManager(self.inference_config, model_config)
@@ -0,0 +1,291 @@
import asyncio
from itertools import count
from time import sleep
from typing import List, Tuple, Union

import rpyc
import torch
import torch.nn as nn
from rpyc.utils.server import ThreadedServer
from torch import multiprocessing as mp
from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.configuration_utils import PretrainedConfig

from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.executor.rpc_worker import rpcWorkerService
from colossalai.inference.utils import find_available_ports
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy

from .engine import InferenceEngine
from .request_handler import RPCRequestHandler

__all__ = ["RPCInferenceEngine"]


def run_server(host, port, event: mp.Event = None):
    server = ThreadedServer(
        rpcWorkerService, port=port, protocol_config={"allow_public_attrs": True, "allow_all_attrs": True}
    )
    if event:
        event.set()
    server.start()


class RPCInferenceEngine(InferenceEngine):
    """
    InferenceEngine which manages the inference process.

    NOTE This `RPCInferenceEngine` is designed for multi-card/online serving.
    The original `InferenceEngine` is designed for single-card, offline service, though it also supports multi-card offline inference.

    Args:
        model_or_path (nn.Module or str): Path or nn.Module of this model. Currently we don't support the `nn.Module` format.
        tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
        inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
        verbose (bool): Determine whether or not to log the generation process.
        model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided.
    """

    def __init__(
        self,
        model_or_path: Union[nn.Module, str],
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        inference_config: InferenceConfig,
        verbose: bool = False,
        model_policy: Policy = None,
    ) -> None:
        """
        If you pass a real model loaded by transformers, the init will take quite a long time.
        Currently we don't support the model (nn.Module) format as the param.
        """

        torch.multiprocessing.set_start_method("spawn", force=True)

        self.inference_config = inference_config
        self.tokenizer = tokenizer
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.verbose = verbose
        self.logger = get_dist_logger(__name__)

        try:
            if isinstance(model_or_path, str):
                self.model_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
            elif isinstance(model_or_path, nn.Module):
                self.logger.error(
                    f"An exception occurred during loading model Config: For {__class__.__name__}, we don't support param like nn.Module currently\n"
                )
                # self.model_config = model_or_path.config
            else:
                self.logger.error(
                    f"An exception occurred during loading model Config: Please pass right param for {__class__.__name__}\n"
                )
        except Exception as e:
            self.logger.error(
                f"An exception occurred during loading model Config: {e}, The path should be transformers-like\n"
            )
        self.generation_config = inference_config.to_generation_config(self.model_config)

        self.tp_size = inference_config.tp_size
        self.events = [mp.Event() for _ in range(self.tp_size)]

        # This operation will init the dist env and models
        self.workers: List[rpcWorkerService] = []
        self.init_workers()

        asyncio.run(self.init_model(model_or_path, model_policy))

        # init the scheduler and logic block manager
        self.request_handler = self.init_scheduler(self.inference_config, self.model_config)

        # init the physical cache
        alloc_shape = self.request_handler.cache_manager.get_physical_cache_shape()
        self.init_device_cache(alloc_shape)

        self.use_cuda_graph = self.inference_config.use_cuda_graph
        self.high_precision = inference_config.high_precision
        self.dtype = inference_config.dtype

        # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
        self.use_spec_dec = False
        self.drafter_model = None
        self.drafter = None
        self.use_glide = False
        self.n_spec_tokens = self.inference_config.max_n_spec_tokens

        self.counter = count()
        self._verify_args()

        self.logger.info("engine init over")

    def _verify_args(self) -> None:
        """Verify the input args"""
        if not isinstance(self.inference_config, InferenceConfig):
            raise TypeError("Invalid type of inference config provided.")
        if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
            raise TypeError(
                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
            )

    def init_workers(self):
        rpc_ports = find_available_ports(self.tp_size)
        self.worker_processes = []
        # mp.set_start_method('spawn')
        for event, rpc_port in zip(self.events, rpc_ports):
            p = mp.Process(target=run_server, args=("localhost", rpc_port, event))
            p.start()
            self.worker_processes.append(p)
            self.logger.info(f"Starting RPC Worker on localhost:{rpc_port}...")

        # Wait for all servers to start
        for event in self.events:
            event.wait()
            event.clear()

        sleep(0.05)

        self.logger.info("init rpc server done.")

        for rpc_port in rpc_ports:
            try:
                conn = rpyc.connect(
                    "localhost",
                    rpc_port,
                    config={"allow_pickle": True, "allow_public_attrs": True, "allow_all_attrs": True},
                )
                self.workers.append(conn.root)
            except:
                raise Exception("conn error!")
        self.logger.info("Build RPC Connection Success! Begin to load model...")
        asyncio.run(self.init_worker_env())
        self.logger.info("init dist env over")

    async def async_parallel_wrapper(self, f, *args, **kwargs):
        async_res = rpyc.async_(f)(*args, **kwargs)
        await asyncio.to_thread(async_res.wait)
        assert async_res.ready
        return async_res.value

    async def init_worker_env(self):
        assert len(self.workers) == self.tp_size, "init workers first"

        dist_group_port = find_available_ports(1)[0]
        init_tasks = [
            self.async_parallel_wrapper(
                worker.init_dist_env, rank, self.inference_config.tp_size, "127.0.0.1", dist_group_port
            )
            for rank, worker in enumerate(self.workers)
        ]

        await asyncio.gather(*init_tasks)

    async def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
        assert len(self.workers) == self.tp_size, "init workers first"

        inference_config_param = self.inference_config.to_rpc_param()
        model_path = model_or_path
        model_policy_param = model_policy.to_rpc_param() if model_policy else None

        init_tasks = [
            self.async_parallel_wrapper(worker.init_model, inference_config_param, model_path, model_policy_param)
            for rank, worker in enumerate(self.workers)
        ]

        await asyncio.gather(*init_tasks)

    def init_scheduler(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> RPCRequestHandler:
        return RPCRequestHandler(inference_config, model_config)

    async def _init_device_cache(self, alloc_shape: Tuple[int, int, int, int]):
        assert len(self.workers) == self.tp_size, "init workers first"

        init_tasks = [self.async_parallel_wrapper(worker.init_cache, alloc_shape) for worker in self.workers]

        await asyncio.gather(*init_tasks)

    def init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]):
        asyncio.run(self._init_device_cache(alloc_shape))

    def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]:
        input_ids = batch.get_1D_inputs()
        sequence_lengths = batch.get_sequence_lengths()

        if batch.is_prompts:
            n_tokens = sequence_lengths.sum().item()
        else:
            n_tokens = batch.current_batch_size
            if batch.use_spec_dec:
                n_tokens = batch.num_tokens_to_verify + 1
                assert n_tokens == input_ids.size(0)
                n_tokens = n_tokens * batch.current_batch_size

        batch_token_ids = None
        config_dict = self.generation_config.to_dict()
        # process repetition_penalty, no_repeat_ngram_size
        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
            if type in config_dict and config_dict[type] is not None:
                batch_token_ids = batch.batch_token_ids

        # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
        use_cuda_graph = False
        if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
            use_cuda_graph = True

        input_meta_data = InputMetaData(
            block_tables=batch.get_block_table_tensor(),
            sequence_lengths=sequence_lengths,
            fd_inter_tensor=None,
            batch_size=batch.current_batch_size,
            is_prompts=batch.is_prompts,
            use_cuda_kernel=self.inference_config.use_cuda_kernel,
            use_cuda_graph=use_cuda_graph,
            high_precision=self.high_precision,
            kv_seq_len=sequence_lengths.max().item(),
            head_dim=batch.head_dim,
            dtype=batch.dtype,
            use_spec_dec=batch.use_spec_dec,
            num_tokens_to_verify=batch.num_tokens_to_verify,
            batch_token_ids=batch_token_ids,
        )

        return input_ids.tolist(), input_meta_data

    async def step_(self, input_token_ids, input_meta_data: InputMetaData):
        assert len(self.workers) == self.tp_size, "init workers first"

        init_tasks = [
            self.async_parallel_wrapper(worker.execute_model_forward, input_token_ids, input_meta_data.to_rpc_param())
            for worker in self.workers
        ]
        ret = await asyncio.gather(*init_tasks)

        return ret[0]

    def step(self) -> List[str]:
        batch = self.request_handler.schedule()

        input_token_ids, input_meta_data = self.prepare_input(batch)
        # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
        next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data))

        # update the request_handler
        next_tokens = torch.tensor(next_tokens, dtype=torch.int)
        self.request_handler.append_next_tokens(next_tokens)
        finished_sequences = self.request_handler.update()
        return finished_sequences

    def kill_workers(self):
        """
        We haven't found a good way to invoke self.kill_workers implicitly, so it is called from __del__.
        """
        assert len(self.workers) != 0
        for proc in self.worker_processes:
            proc.kill()
            proc.join()
        self.logger.info("worker killed, serving end")

    def __del__(self):
        self.kill_workers()
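The fan-out in `async_parallel_wrapper`/`step_` rests on rpyc's asynchronous proxies: `rpyc.async_` returns an `AsyncResult` immediately, and the blocking `wait()` is pushed onto a thread so one asyncio event loop can drive several workers. A minimal sketch of the same pattern outside the engine (the service method and ports are placeholders):

import asyncio

import rpyc


async def call_worker(conn, x):
    # returns immediately; the result is fetched once wait() completes in a thread
    async_res = rpyc.async_(conn.root.heavy_task)(x)  # exposed_heavy_task is hypothetical
    await asyncio.to_thread(async_res.wait)
    return async_res.value


async def fan_out(conns, x):
    return await asyncio.gather(*(call_worker(c, x) for c in conns))


# conns = [rpyc.connect("localhost", p) for p in (18861, 18862)]  # placeholder ports
# results = asyncio.run(fan_out(conns, 42))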
@@ -0,0 +1,300 @@
import os
from typing import List, Tuple, Union

import rpyc
import torch
import torch.distributed as dist
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.policy import (
    NoPaddingBaichuanModelInferPolicy,
    NoPaddingLlamaModelInferPolicy,
    model_policy_map,
)
from colossalai.inference.sampler import search_tokens
from colossalai.inference.utils import get_model_size, has_index_file
from colossalai.interface import ModelWrapper
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy

PP_AXIS, TP_AXIS = 0, 1

_SUPPORTED_MODELS = {
    "LlamaForCausalLM": LlamaForCausalLM,
    "BaichuanForCausalLM": AutoModelForCausalLM,
}

_SUPPORTED_MODEL_POLICIES = {
    "NoPaddingLlamaModelInferPolicy": NoPaddingLlamaModelInferPolicy,
    "NoPaddingBaichuanModelInferPolicy": NoPaddingBaichuanModelInferPolicy,
}

logger = get_dist_logger(__name__)


class rpcWorkerService(rpyc.Service):
    """
    Executes the computation tasks and manages its own KV cache.

    Functions with the `exposed_` prefix are invoked by the client.
    """

    def exposed_init_dist_env(self, rank, world_size, master_address, master_port):
        logger.info(f"init process group for rank {rank}")
        colossalai.launch(rank=rank, world_size=world_size, port=master_port, host=master_address)
        logger.info(f"init process group done for rank {rank}")

    def exposed_init_model(
        self, inference_config_param: dict, model_or_path: Union[nn.Module, str], model_policy_param: str = None
    ):
        assert dist.is_initialized(), "invoke init_dist_env first please!"

        self.inference_config = InferenceConfig.from_rpc_param(inference_config_param)
        model_policy = _SUPPORTED_MODEL_POLICIES[model_policy_param]() if model_policy_param else None

        self.dtype = self.inference_config.dtype
        self.verbose = True

        self._init_model(model_or_path, model_policy)
        self._init_fd_tensor()
        self._init_output_tensor()
        logger.info(f"init model done for rank {dist.get_rank()}")

    def exposed_init_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]):
        """Initialize the physical cache on the device.

        For each layer of the model, we allocate two tensors for key and value respectively,
        with shape of [num_blocks, num_kv_heads, block_size, head_size]
        """
        kalloc_shape, valloc_shape = alloc_shape
        num_layers = self.model_config.num_hidden_layers

        self.k_cache: List[torch.Tensor] = []
        self.v_cache: List[torch.Tensor] = []
        for _ in range(num_layers):
            self.k_cache.append(
                torch.zeros(
                    kalloc_shape,
                    dtype=self.inference_config.kv_cache_dtype,
                    device=get_accelerator().get_current_device(),
                )
            )
            self.v_cache.append(
                torch.zeros(
                    valloc_shape,
                    dtype=self.inference_config.kv_cache_dtype,
                    device=get_accelerator().get_current_device(),
                )
            )
        logger.info("physical cache init over")

    def exposed_execute_model_forward(self, input_token_ids_param: List[int], input_meta_data_param: dict):
        # prepare the data for model forward
        input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
        input_meta_data.fd_inter_tensor = self.fd_inter_tensor
        if input_meta_data.is_prompts:
            n_tokens = input_meta_data.sequence_lengths.sum().item()
        else:
            n_tokens = input_meta_data.batch_size
        input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device)

        # execute the model
        logits = self.model(
            input_token_ids,
            self.output_tensor[:n_tokens],
            input_meta_data,
            self.k_cache,
            self.v_cache,
        )

        # sampler
        if self.inference_config.pad_input:
            logits = logits[:, -1, :]
        next_tokens = search_tokens(
            self.inference_config.to_generation_config(self.model_config),
            logits,
            input_meta_data.is_prompts,
            input_meta_data.batch_token_ids,
        )

        # return the generated tokens to the scheduler
        return next_tokens.tolist()

    def _init_output_tensor(self):
        alloc_shape = (
            self.inference_config.max_batch_size
            * (self.inference_config.max_input_len + self.inference_config.max_output_len),
            self.model_config.hidden_size // self.inference_config.tp_size,
        )
        self.output_tensor = torch.zeros(alloc_shape, dtype=self.dtype, device=self.device)

    def _init_fd_tensor(self):
        fd_inter_tensor = FDIntermTensors()

        if fd_inter_tensor._tensors_initialized:
            fd_inter_tensor._reset()

        # For Spec-Dec, process the speculated tokens plus the token in the last step for each seq
        max_n_tokens = self.inference_config.max_batch_size
        max_n_tokens *= self.inference_config.max_n_spec_tokens + 1

        inference_config = self.inference_config
        kv_max_split_num = (
            inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
        ) // inference_config.block_size
        head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads

        fd_inter_tensor.initialize(
            max_batch_size=max_n_tokens,
            num_attn_heads=self.model_config.num_attention_heads // self.inference_config.tp_size,
            kv_max_split_num=kv_max_split_num,
            head_dim=head_dim,
            dtype=self.dtype,
            device=get_accelerator().get_current_device(),
        )

        self.fd_inter_tensor = fd_inter_tensor

    def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
        """
        Shard the model and/or load its weights.

        Shard model: when tp_size > 1, the model is sharded with the given model_policy.
        Load weights: if a local model path is passed, weights are loaded via checkpoint_io; for a remote transformers repo, the `from_pretrained` API of the transformers lib is used.

        Args:
            model_or_path Union[nn.Module, str]: path to the checkpoint or a model in transformers format.
            model_policy (Policy): the policy to replace the model.
        """

        if isinstance(model_or_path, str):
            is_local = os.path.isdir(model_or_path)
            try:
                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
                arch = getattr(hf_config, "architectures")[0]
                if is_local:
                    model = _SUPPORTED_MODELS[arch](hf_config)
                else:
                    # load the real checkpoint
                    model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
            except Exception as e:
                logger.error(
                    f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
                )
        else:
            model = model_or_path

        self.model_config = model.config

        torch.cuda.empty_cache()
        init_gpu_memory = torch.cuda.mem_get_info()[0]

        self.device = get_accelerator().get_current_device()
        torch.cuda.set_device(self.device)
        if self.verbose:
            logger.info(f"the device is {self.device}")

        model = model.to(dtype=self.dtype, non_blocking=False).eval()

        if self.verbose:
            logger.info(
                f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
            )

        if model_policy is None:
            if self.inference_config.pad_input:
                model_type = "padding_" + self.model_config.model_type
            else:
                model_type = "nopadding_" + self.model_config.model_type
            model_policy = model_policy_map[model_type]()

        pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
        tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

        self.model = self._shardformer(
            model,
            model_policy,
            None,
            tp_group=tp_group,
        )

        self.model = ModelWrapper(model).to(device=get_accelerator().get_current_device())

        if self.verbose:
            logger.info(
                f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
            )

        if isinstance(model_or_path, str) and is_local:
            from colossalai.inference.core.plugin import InferCheckpoint_io

            cpt_io = InferCheckpoint_io()
            if_has_index_file, model_index_file = has_index_file(model_or_path)
            assert if_has_index_file, "the model path is invalid"
            cpt_io.load_model(self.model, model_index_file)

        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        peak_memory = init_gpu_memory - free_gpu_memory
        if self.verbose:
            logger.info(
                f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
            )

    def _shardformer(
        self,
        model: nn.Module,
        model_policy: Policy,
        stage_manager: PipelineStageManager = None,
        tp_group: ProcessGroupMesh = None,
    ) -> nn.Module:
        """
        Initialize ShardConfig and replace the model with shardformer.

        Args:
            model (nn.Module): Path or nn.Module of this model.
            model_policy (Policy): The policy to shardformer model which is determined by the model type.
            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
            tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.

        Returns:
            nn.Module: The model optimized by Shardformer.
        """

        shardconfig = ShardConfig(
            tensor_parallel_process_group=tp_group,
            pipeline_stage_manager=stage_manager,
            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
            enable_fused_normalization=False,
            enable_all_optimization=False,
            enable_flash_attention=False,
            enable_jit_fused=False,
            enable_sequence_parallelism=False,
        )
        shardformer = ShardFormer(shard_config=shardconfig)
        shard_model, _ = shardformer.optimize(model, model_policy)
        return shard_model

    def exposed_compute_only_for_test(self):
        dist_rank = dist.get_rank()

        # Dummy data for each worker
        data = torch.tensor([dist_rank], dtype=torch.float).cuda(dist_rank)
        dist.barrier()

        # Perform distributed all_reduce
        dist.all_reduce(data, op=dist.ReduceOp.SUM)

        dist.barrier()
        logger.info(f"Worker rank {dist_rank}: Sum after all_reduce: {data.item()}")

        return data.item()
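The `exposed_` prefix is rpyc's access-control convention: only such methods are reachable from the client, which drops the prefix when calling them through `conn.root`. A minimal sketch (service name and port are illustrative):

import rpyc
from rpyc.utils.server import ThreadedServer


class EchoService(rpyc.Service):
    # only methods with the exposed_ prefix are callable remotely
    def exposed_echo(self, msg: str) -> str:
        return f"worker got: {msg}"


if __name__ == "__main__":
    server = ThreadedServer(EchoService, port=18861)  # placeholder port
    # server.start()                          # blocks; run in its own process, like run_server() above
    # conn = rpyc.connect("localhost", 18861)
    # print(conn.root.echo("ping"))           # client drops the exposed_ prefix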
@@ -1,4 +1,4 @@
 from .block_cache import CacheBlock
-from .kvcache_manager import KVCacheManager
+from .kvcache_manager import KVCacheManager, RPCKVCacheManager

-__all__ = ["CacheBlock", "KVCacheManager"]
+__all__ = ["CacheBlock", "KVCacheManager", "RPCKVCacheManager"]
@@ -497,3 +497,80 @@ class KVCacheManager:
             k_cache.append(torch.zeros(kalloc_shape, dtype=self.kv_cache_dtype, device=self.device))
             v_cache.append(torch.zeros(valloc_shape, dtype=self.kv_cache_dtype, device=self.device))
         return k_cache, v_cache
+
+
+class RPCKVCacheManager(KVCacheManager):
+    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
+        self.logger = get_dist_logger(__name__)
+        self.device = get_current_device()
+        self.config = config
+
+        # Parallel settings
+        self.tp_size = config.tp_size
+        # Model settings
+        self.dtype = config.dtype
+        self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
+        self.num_layers = model_config.num_hidden_layers
+        self.head_num = model_config.num_attention_heads
+        self.head_size = model_config.hidden_size // self.head_num
+        if hasattr(model_config, "num_key_value_heads"):
+            self.kv_head_num = model_config.num_key_value_heads
+        else:
+            self.kv_head_num = self.head_num
+
+        if config.kv_cache_dtype is None:
+            self.kv_cache_dtype = config.dtype
+        else:
+            self.kv_cache_dtype = config.kv_cache_dtype
+
+        assert (
+            self.kv_head_num % self.tp_size == 0
+        ), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}"
+        self.kv_head_num //= self.tp_size
+        self.beam_width = config.beam_width
+        self.max_batch_size = config.max_batch_size
+        self.max_input_length = config.max_input_len
+        self.max_output_length = config.max_output_len
+        # Cache block settings
+        self.block_size = config.block_size
+        # NOTE: `num_blocks` is not configured directly, but derived from the maximum input/output length and the maximum batch size
+        self.max_blocks_per_sequence = (
+            self.max_input_length + self.max_output_length + self.block_size - 1
+        ) // self.block_size
+        self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
+
+        # Logical cache blocks allocation
+        self._available_blocks = self.num_blocks
+        self._cache_blocks = tuple(self._init_logical_caches())
+        # block availability state: 0 -> allocated, 1 -> free
+        self._block_states = torch.ones((self.num_blocks,), dtype=torch.bool)
+        self._block_states_cum = torch.zeros(size=(self.num_blocks + 1,), dtype=torch.int64)
+        self._block_finder = torch.zeros((self.num_blocks,), dtype=torch.int64)
+
+    def get_physical_cache_shape(self) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
+        # Physical cache allocation
+        if self.config.use_cuda_kernel:
+            x = 16 // torch.tensor([], dtype=self.config.dtype).element_size()
+            kalloc_shape = (self.num_blocks, self.kv_head_num, self.head_size // x, self.block_size, x)
+            valloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
+            self.logger.info(
+                f"Allocating K cache with shape: {kalloc_shape}, V cache with shape: {valloc_shape} consisting of {self.num_blocks} blocks."
+            )
+        else:
+            alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
+            kalloc_shape = alloc_shape
+            valloc_shape = alloc_shape
+            self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
+        return kalloc_shape, valloc_shape
+
+    def get_kv_cache(self):
+        """Get k_cache and v_cache"""
+        return NotImplementedError
+
+    def _init_logical_caches(self):
+        """Initialize the logical cache blocks."""
+        blocks = []
+        for i in range(self.num_blocks):
+            cache_block = CacheBlock(i, self.block_size, self.elem_size_in_bytes, k_ptrs=None, v_ptrs=None)
+            blocks.append(cache_block)
+        return blocks
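`RPCKVCacheManager` keeps only this logical block bookkeeping on the client; the physical tensors are allocated on the workers from the shapes returned by `get_physical_cache_shape`. A worked example of the block arithmetic with made-up sizes, not repository defaults:

max_input_len, max_output_len, block_size = 512, 512, 16
max_batch_size, beam_width = 8, 1

max_blocks_per_sequence = (max_input_len + max_output_len + block_size - 1) // block_size  # 64
num_blocks = max_blocks_per_sequence * max_batch_size * beam_width                          # 512

# without the custom CUDA-kernel layout, K and V share one per-layer shape
kv_head_num, head_size = 8, 128
alloc_shape = (num_blocks, kv_head_num, block_size, head_size)
print(alloc_shape)  # (512, 8, 16, 128)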
@@ -1,10 +1,9 @@
 # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
+from typing import List
+
 import torch
 import torch.nn.functional as F

-from colossalai.inference.batch_bucket import BatchBucket
-
 _LOGIT_PROCESSOR_MAP = {}
@@ -22,7 +21,7 @@ def register_logit_processor(process_type):


 @register_logit_processor("no_repeat_ngram_size")
-def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBucket):
+def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: List[List[int]]):
     """
     enforces no repetition of n-grams to avoid repetitions of word sequences.
     """
@@ -31,7 +30,6 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBuck
         raise ValueError(f"'temperature={ngram_size}' should be a strictly positive integer.")

     if ngram_size != 0:
-        batch_token_ids = batch.batch_token_ids
         batch_size = len(batch_token_ids)

         for batch_id in range(batch_size):
@@ -55,7 +53,7 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBuck


 @register_logit_processor("repetition_penalty")
-def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket):
+def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: List[List[int]]):
     """
     apply the penalty to the tokens present in the prompt.
     """
@@ -67,7 +65,6 @@ def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket)

     # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
     if penalty != 1.0:
-        batch_token_ids = batch.batch_token_ids
         for batch_id in range(len(batch_token_ids)):
             current_logit = logits[batch_id]
             current_token = torch.tensor(batch_token_ids[batch_id], dtype=torch.long, device=logits.device)
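After this change the processors receive plain token-id histories rather than a `BatchBucket`, which is what lets the RPC worker call them with deserialized data. A hedged sketch of a repetition penalty applied that way, using the usual convention and illustrative values rather than the repository's exact kernel:

import torch

logits = torch.full((2, 6), 2.0)
batch_token_ids = [[1, 3], [0]]   # per-sequence token histories, as passed by search_tokens
penalty = 1.5

for batch_id, history in enumerate(batch_token_ids):
    current_logit = logits[batch_id]
    current_token = torch.tensor(history, dtype=torch.long)
    chosen = current_logit[current_token]
    # positive logits are divided by the penalty, negative ones multiplied
    current_logit[current_token] = torch.where(chosen > 0, chosen / penalty, chosen * penalty)

print(logits[0])  # entries 1 and 3 drop from 2.0 to about 1.33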
@@ -1,3 +1,4 @@
+from colossalai.inference.config import RPC_PARAM
 from colossalai.inference.modeling.layers.baichuan_tp_linear import (
     BaichuanLMHeadLinear1D_Col,
     BaichuanWpackLinear1D_Col,
@@ -18,7 +19,7 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription,
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy


-class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy):
+class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
     def __init__(self) -> None:
         super().__init__()

@@ -100,3 +101,10 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy):
     def postprocess(self):
         init_to_get_rotary(self.model.model)
         return self.model
+
+    def to_rpc_param(self) -> str:
+        return __class__.__name__
+
+    @staticmethod
+    def from_rpc_param() -> "NoPaddingBaichuanModelInferPolicy":
+        return NoPaddingBaichuanModelInferPolicy()
@@ -1,5 +1,6 @@
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm

+from colossalai.inference.config import RPC_PARAM
 from colossalai.inference.modeling.models.nopadding_llama import (
     NopadLlamaAttention,
     NopadLlamaMLP,
@ -14,7 +15,7 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription,
|
||||||
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
||||||
|
|
||||||
|
|
||||||
class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -102,3 +103,10 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||||
def postprocess(self):
|
def postprocess(self):
|
||||||
init_to_get_rotary(self.model.model, self.model.config.rope_theta)
|
init_to_get_rotary(self.model.model, self.model.config.rope_theta)
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
def to_rpc_param(self) -> str:
|
||||||
|
return __class__.__name__
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_rpc_param() -> "NoPaddingLlamaModelInferPolicy":
|
||||||
|
return NoPaddingLlamaModelInferPolicy()
|
||||||
|
|
|
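Both policies now serialize to nothing more than their class name. A minimal sketch of the intended round trip, assuming a hypothetical name-to-class registry and helper names that are not part of this commit: the client sends a plain string over rpyc and the server rebuilds a fresh policy from it.

from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy

# hypothetical name-to-class registry, for illustration only
_POLICY_REGISTRY = {"NoPaddingLlamaModelInferPolicy": NoPaddingLlamaModelInferPolicy}

def send_policy(policy) -> str:
    # client side: only a POD string crosses the rpyc connection
    return policy.to_rpc_param()

def recv_policy(name: str):
    # server side: reconstruct the policy locally from the transported name
    return _POLICY_REGISTRY[name].from_rpc_param()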
@ -1,6 +1,9 @@
-from typing import List, Tuple
+from typing import List, Optional, Tuple

import torch
+from transformers.generation import GenerationConfig
+
+from colossalai.inference.logit_processors import logit_processor


def greedy_sample(

@ -59,3 +62,47 @@ def beam_search_sample(
        results.append((next_token_ids, parent_ids))
    return results
+
+
+def _sample(probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig, is_prompt: bool = False):
+    if generation_config.num_beams == 1:
+        if generation_config.do_sample:
+            sample_tokens = multinomial_sample(generation_config, probs)
+        else:
+            sample_tokens = greedy_sample(generation_config, logprobs)
+    else:
+        sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=is_prompt)
+
+    return sample_tokens
+
+
+def search_tokens(
+    generation_config: GenerationConfig,
+    logits,
+    is_prompt: bool = False,
+    batch_token_ids: Optional[List[List[int]]] = None,
+):
+    """
+    Sample tokens for finished requests.
+    """
+    # NOTE: need to decide the granularity to process logits (sequence or batch)
+    config_dict = generation_config.to_dict()
+    # process repetition_penalty, no_repeat_ngram_size
+    for type in ["repetition_penalty", "no_repeat_ngram_size"]:
+        if type in config_dict and config_dict[type] is not None:
+            logits = logit_processor(type, logits, config_dict[type], batch_token_ids)
+
+    # do logit processor
+    if generation_config.do_sample:
+        # process temperature, top_k, top_p
+        for type in ["temperature", "top_k", "top_p"]:
+            if type in config_dict and config_dict[type] is not None:
+                logits = logit_processor(type, logits, config_dict[type])
+
+    # calculate probs
+    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
+
+    # sample the next tokens
+    sample_tokens = _sample(probs, logprobs, generation_config, is_prompt)
+    return sample_tokens
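A minimal usage sketch of the new `search_tokens` entry point; the tensor shapes, sampling parameters, token ids, and the import path below are illustrative assumptions:

import torch
from transformers.generation import GenerationConfig

from colossalai.inference.sampler import search_tokens  # assumed module path

logits = torch.randn(2, 32000)                      # one row of logits per running sequence
batch_token_ids = [[1, 15043, 29892], [1, 3532]]    # prompt + generated ids so far (illustrative)

gen_cfg = GenerationConfig(do_sample=True, temperature=0.8, top_k=50, top_p=0.9, repetition_penalty=1.2)
next_tokens = search_tokens(gen_cfg, logits, is_prompt=False, batch_token_ids=batch_token_ids)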
@ -9,6 +9,8 @@ from typing import Optional, Tuple
import torch
from torch import nn

+from colossalai.testing import free_port
+

def init_to_get_rotary(self, base=10000, use_elem=False):
    """

@ -102,3 +104,12 @@ def get_model_size(model: nn.Module):
    for key, param in model.named_parameters():
        total_size += param.element_size() * param.numel()
    return total_size / (1024**3)
+
+
+def find_available_ports(num: int):
+    try:
+        free_ports = [free_port() for i in range(num)]
+    except OSError as e:
+        print(f"An OS error occurred: {e}")
+        raise RuntimeError("Error finding available ports")
+    return free_ports
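For example, a hedged sketch of how the helper can be used to reserve one rpyc port per tensor-parallel worker before launching the server processes; the variable name and port values are illustrative:

tp_size = 2                                  # assumed number of tensor-parallel workers
rpc_ports = find_available_ports(tp_size)    # e.g. [41735, 41736]; actual values vary per machine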
@ -19,4 +19,5 @@ datasets
pydantic
ray
peft>=0.7.1
+rpyc==6.0.0
#auto-gptq now not support torch1.12
@ -19,3 +19,4 @@ protobuf
transformers==4.36.2
peft>=0.7.1
bitsandbytes>=0.39.0
+rpyc==6.0.0
@ -0,0 +1,105 @@
+import random
+
+import numpy as np
+import pytest
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+
+from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
+from colossalai.inference.core.rpc_engine import RPCInferenceEngine
+from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+
+
+def setup_seed(seed):
+    torch.manual_seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+
+def check_inference_engine(tp_size, use_engine=False, prompt_template=None, do_sample=True, policy=None):
+    setup_seed(20)
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+    model = "meta-llama/Llama-2-7b-hf"  # remote model path
+    inputs = [
+        "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
+        "介绍一下武汉,",
+    ]
+
+    output_len = 38
+    top_p = 0.5
+    top_k = 50
+
+    if use_engine:
+        inference_config = InferenceConfig(
+            max_output_len=output_len,
+            prompt_template=prompt_template,
+            dtype="fp32",
+            use_cuda_kernel=True,
+            tp_size=tp_size,
+        )
+        inference_engine = RPCInferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
+        assert inference_engine.generation_config.max_new_tokens == output_len
+        inference_engine.add_request(prompts=inputs)
+        assert inference_engine.request_handler._has_waiting()
+        generation_config = GenerationConfig(
+            max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
+        )
+        outputs = inference_engine.generate(generation_config=generation_config)
+    else:
+        if prompt_template:
+            # apply prompt template
+            inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
+        model = AutoModelForCausalLM.from_pretrained(model).cuda()
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
+        inputs = inputs.cuda()
+        generation_config = GenerationConfig(
+            do_sample=do_sample,
+            dtype="fp32",
+            top_p=top_p,
+            top_k=top_k,
+            pad_token_id=tokenizer.pad_token_id,
+            max_new_tokens=output_len,
+        )
+        outputs = model.generate(inputs, generation_config=generation_config)
+        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+    return outputs
+
+
+def run_engine(tp_size, **kwargs):
+    return check_inference_engine(tp_size=tp_size, **kwargs)
+
+
+@pytest.mark.largedist
+@parameterize("prompt_template", [None, "llama"])
+@parameterize("do_sample", [False])
+@rerun_if_address_is_in_use()
+def test_tp_engine(prompt_template, do_sample):
+    if torch.multiprocessing.get_start_method(allow_none=True) is None:
+        torch.multiprocessing.set_start_method("spawn")
+    kwargs1 = {
+        "use_engine": True,
+        "prompt_template": prompt_template,
+        "do_sample": do_sample,
+        "policy": NoPaddingLlamaModelInferPolicy(),
+    }
+
+    kwargs2 = {"use_engine": False, "prompt_template": prompt_template, "do_sample": do_sample, "policy": None}
+
+    colossal_tp_1_output = run_engine(1, **kwargs1)
+    colossal_tp_2_output = run_engine(2, **kwargs1)
+    transformer_tp_1_output = run_engine(1, **kwargs2)
+
+    for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
+        assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
+        assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
+
+
+if __name__ == "__main__":
+    torch.multiprocessing.set_start_method("spawn")  # the spawn start method is required; fork-based start methods are not supported here
+    test_tp_engine()
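As a usage note, and assuming the usual test layout of this repository: the test is collected under the `largedist` pytest marker (for example `pytest -m largedist` on this file), the TP=2 case needs at least two CUDA devices, and running the module directly also works because it forces the spawn start method before anything else.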