ColossalAI/applications/Chat/inference/server.py

import argparse
import os
from threading import Lock
from typing import Generator, List, Optional
import torch
import uvicorn
from coati.quant import llama_load_quant, low_resource_init
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn
CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
MAX_LEN = 512
running_lock = Lock()


class GenerationTaskReq(BaseModel):
    max_new_tokens: int = Field(gt=0, le=512, example=64)
    history: List[Dialogue] = Field(min_items=1)
    top_k: Optional[int] = Field(default=None, gt=0, example=50)
    top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
    temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
    repetition_penalty: Optional[float] = Field(default=None, gt=1.0, example=1.2)

limiter = Limiter(key_func=get_remote_address)
app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# set CORS
origin_spec_from_env = os.environ.get('CORS_ORIGIN', None)
if origin_spec_from_env is not None:
    # allow CORS from the specified origins
    origins = os.environ['CORS_ORIGIN'].split(',')
else:
    # allow CORS from all origins
    origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
    inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
    # TODO(ver217): streaming generation does not support repetition_penalty now
    model_kwargs = {
        'max_generate_tokens': max_new_tokens,
        'early_stopping': True,
        'top_k': top_k,
        'top_p': top_p,
        'temperature': temperature,
        'prepare_inputs_fn': model.prepare_inputs_for_generation,
        'update_model_kwargs_fn': update_model_kwargs_fn,
    }
    is_first_word = True
    generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock)
    for output in generator:
        output = output.cpu()
        tokens = tokenizer.convert_ids_to_tokens(output, skip_special_tokens=True)
        current_sub_tokens = []
        for token in tokens:
            if token in tokenizer.all_special_tokens:
                continue
            current_sub_tokens.append(token)
        if current_sub_tokens:
            out_string = tokenizer.sp_model.decode(current_sub_tokens)
            if is_first_word:
                out_string = out_string.lstrip()
                is_first_word = False
            elif current_sub_tokens[0].startswith('▁'):
                # '▁' is the SentencePiece word-boundary marker; restore the leading
                # space explicitly because whitespace will be ignored by the frontend
                out_string = ' ' + out_string
            yield out_string


async def event_generator(request: Request, generator: Generator):
    while True:
        if await request.is_disconnected():
            break
        try:
            yield {'event': 'generate', 'data': next(generator)}
        except StopIteration:
            yield {'event': 'end', 'data': ''}
            break


@app.post('/generate/stream')
@limiter.limit('1/second')
def generate(data: GenerationTaskReq, request: Request):
    prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
    event_source = event_generator(
        request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature))
    return EventSourceResponse(event_source)


@app.post('/generate')
@limiter.limit('1/second')
def generate_no_stream(data: GenerationTaskReq, request: Request):
    prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
    if prompt_processor.has_censored_words(prompt):
        return prompt_processor.SAFE_RESPONSE
    inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
    with running_lock:
        output = model.generate(**inputs, **data.dict(exclude={'history'}))
    output = output.cpu()
    prompt_len = inputs['input_ids'].size(1)
    response = output[0, prompt_len:]
    out_string = tokenizer.decode(response, skip_special_tokens=True)
    out_string = prompt_processor.postprocess_output(out_string)
    if prompt_processor.has_censored_words(out_string):
        return prompt_processor.SAFE_RESPONSE
    return out_string


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'pretrained',
        help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.')
    parser.add_argument('--quant',
                        choices=['8bit', '4bit'],
                        default=None,
                        help='Quantization mode. Default: None (no quantization, fp16).')
    parser.add_argument(
        '--gptq_checkpoint',
        default=None,
        help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.')
    parser.add_argument('--gptq_group_size',
                        type=int,
                        default=128,
                        help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
    parser.add_argument('--http_host', default='0.0.0.0')
    parser.add_argument('--http_port', type=int, default=7070)
    parser.add_argument('--profanity_file',
                        default=None,
                        help='Path to profanity words list. It should be a JSON file containing a list of words.')
    args = parser.parse_args()

    if args.quant == '4bit':
        assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained)

    if args.profanity_file is not None:
        censored_words = load_json(args.profanity_file)
    else:
        censored_words = []
    prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)

    if args.quant == '4bit':
        with low_resource_init():
            config = LlamaConfig.from_pretrained(args.pretrained)
            model = LlamaForCausalLM(config)
        model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
        model.cuda()
    else:
        model = LlamaForCausalLM.from_pretrained(
            args.pretrained,
            load_in_8bit=(args.quant == '8bit'),
            torch_dtype=torch.float16,
            device_map="auto",
        )
        if args.quant != '8bit':
            model.half()    # seems to fix bugs for some users.
        model.eval()

    config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
    server = uvicorn.Server(config=config)
    server.run()