"""
API and LLM wrapper class for running LLMs locally

Usage:

import os
model_path = os.environ.get("ZH_MODEL_PATH")
model_name = "chatglm2"
colossal_api = ColossalAPI(model_name, model_path)
llm = ColossalLLM(n=1, api=colossal_api)
# Chinese test prompt, roughly: "Continue the article: once the Awakening of Insects
# passes, the spring chill deepens. First it is merely chilly and raw, then the rainy
# season begins,"
TEST_PROMPT_CHATGLM="续写文章:惊蛰一过,春寒加剧。先是料料峭峭,继而雨季开始,"
logger.info(llm(TEST_PROMPT_CHATGLM, max_new_tokens=100), verbose=True)

"""

from typing import Any, List, Mapping, Optional

import torch
from colossalqa.local.utils import get_response, post_http_request
from colossalqa.mylogging import get_logger
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = get_logger()


class ColossalAPI:
    """
    API for calling LLM.generate on a locally loaded Hugging Face causal LM.
    Instances are cached per (model_type, model_path, ckpt_path) so each model
    is only loaded once.
    """

    __instances = dict()

    def __init__(self, model_type: str, model_path: str, ckpt_path: Optional[str] = None) -> None:
        """
        Configure the model
        """
        if model_type + model_path + (ckpt_path or "") in ColossalAPI.__instances:
            return
        else:
            ColossalAPI.__instances[model_type + model_path + (ckpt_path or "")] = self
        self.model_type = model_type
        self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True)

        if ckpt_path is not None:
            state_dict = torch.load(ckpt_path)
            self.model.load_state_dict(state_dict)
        self.model.to(torch.cuda.current_device())

        # Configure tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        self.model.eval()

    @staticmethod
    def get_api(model_type: str, model_path: str, ckpt_path: Optional[str] = None):
        if model_type + model_path + (ckpt_path or "") in ColossalAPI.__instances:
            return ColossalAPI.__instances[model_type + model_path + (ckpt_path or "")]
        else:
            return ColossalAPI(model_type, model_path, ckpt_path)

    def generate(self, input: str, **kwargs) -> str:
        """
        Generate a response given the prompt
        Args:
            input: input string
            **kwargs: language model keyword type arguments, such as top_k, top_p, temperature, max_new_tokens...
        Returns:
            output: output string
        """
        if self.model_type in ["chatglm", "chatglm2"]:
            inputs = {
                k: v.to(torch.cuda.current_device()) for k, v in self.tokenizer(input, return_tensors="pt").items()
            }
        else:
            inputs = {
                "input_ids": self.tokenizer(input, return_tensors="pt")["input_ids"].to(torch.cuda.current_device())
            }

        output = self.model.generate(**inputs, **kwargs)
        output = output.cpu()
        prompt_len = inputs["input_ids"].size(1)
        response = output[0, prompt_len:]
        output = self.tokenizer.decode(response, skip_special_tokens=True)
        return output
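
# Example (illustrative sketch, not part of the original module): fetch the
# cached ColossalAPI for a model and call generate() directly. The model name
# and path below are placeholders for a locally available checkpoint.
#
#   api = ColossalAPI.get_api("chatglm2", "/path/to/chatglm2-6b")
#   print(api.generate("What is ColossalAI?", max_new_tokens=32))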


class VllmAPI:
    """
    API for a model served over HTTP, e.g. by a vLLM server exposing a /generate endpoint.
    """

    def __init__(self, host: str = "localhost", port: int = 8077) -> None:
        # Configure the API for a model served through the web
        self.host = host
        self.port = port
        self.url = f"http://{self.host}:{self.port}/generate"

    def generate(self, input: str, **kwargs):
        # Query the server and strip the echoed prompt from the returned text
        output = get_response(post_http_request(input, self.url, **kwargs))[0]
        return output[len(input) :]
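
# Example (sketch, assuming a vLLM API server is already running and serving
# /generate at the given host/port; the values and the max_tokens argument
# forwarded to the server are placeholders):
#
#   vllm_api = VllmAPI(host="localhost", port=8077)
#   print(vllm_api.generate("What is ColossalAI?", max_tokens=64))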


class ColossalLLM(LLM):
    """
    Langchain LLM wrapper for a local LLM
    """

    n: int
    api: Any
    kwargs = {"max_new_tokens": 100}

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        logger.info(f"kwargs:{kwargs}\nstop:{stop}\nprompt:{prompt}", verbose=self.verbose)
        for k in self.kwargs:
            if k not in kwargs:
                kwargs[k] = self.kwargs[k]

        generate_args = {k: kwargs[k] for k in kwargs if k not in ["stop", "n"]}
        out = self.api.generate(prompt, **generate_args)
        if isinstance(stop, list) and len(stop) != 0:
            for stopping_words in stop:
                if stopping_words in out:
                    out = out.split(stopping_words)[0]
        logger.info(f"{prompt}{out}", verbose=self.verbose)
        return out

    @property
    def _identifying_params(self) -> Mapping[str, int]:
        """Get the identifying parameters."""
        return {"n": self.n}

    def get_token_ids(self, text: str) -> List[int]:
        """Return the ordered ids of the tokens in a text.

        Args:
            text: The string input to tokenize.

        Returns:
            A list of ids corresponding to the tokens in the text, in order they occur
            in the text.
        """
        # Use the Colossal LLM's tokenizer instead of langchain's cached GPT2 tokenizer
        return self.api.tokenizer.encode(text)
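
# Example (sketch): using ColossalLLM as a LangChain-compatible LLM. The model
# path is a placeholder; `stop` and the generation kwargs are forwarded to
# ColossalAPI.generate as in `_call` above.
#
#   colossal_api = ColossalAPI.get_api("chatglm2", "/path/to/chatglm2-6b")
#   llm = ColossalLLM(n=1, api=colossal_api)
#   answer = llm("What is ColossalAI?", max_new_tokens=100, stop=["\n\n"])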


class VllmLLM(LLM):
    """
    Langchain LLM wrapper for a model served through a vLLM HTTP endpoint
    """

    n: int
    api: Any
    kwargs = {"max_new_tokens": 100}

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        for k in self.kwargs:
            if k not in kwargs:
                kwargs[k] = self.kwargs[k]
        logger.info(f"kwargs:{kwargs}\nstop:{stop}\nprompt:{prompt}", verbose=self.verbose)
        generate_args = {k: kwargs[k] for k in kwargs if k in ["n", "max_tokens", "temperature", "stream"]}
        out = self.api.generate(prompt, **generate_args)
        # Guard against stop=None before truncating at stop words
        if isinstance(stop, list) and len(stop) != 0:
            for stopping_words in stop:
                if stopping_words in out:
                    out = out.split(stopping_words)[0]
        logger.info(f"{prompt}{out}", verbose=self.verbose)
        return out

    def set_host_port(self, host: str = "localhost", port: int = 8077, **kwargs) -> None:
        if "max_tokens" not in kwargs:
            kwargs["max_tokens"] = 100
        self.kwargs = kwargs
        self.api = VllmAPI(host=host, port=port)

    @property
    def _identifying_params(self) -> Mapping[str, int]:
        """Get the identifying parameters."""
        return {"n": self.n}
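
# Example (sketch, assuming a vLLM server is reachable at the given host/port;
# values are placeholders): the API handle is attached after construction via
# set_host_port, which also sets default generation kwargs such as max_tokens.
#
#   vllm_llm = VllmLLM(n=1)
#   vllm_llm.set_host_port(host="localhost", port=8077, max_tokens=100)
#   print(vllm_llm("What is ColossalAI?"))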