mirror of https://github.com/THUDM/ChatGLM2-6B
fix: openai_api stream API: the client only received the complete generated text in one batch after the server had finished generating it

parent 53f0106817
commit fcd2d7f4bb
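In short: the streaming branch previously returned starlette's StreamingResponse around hand-built "data: ...\n\n" frames, and clients only saw the output once generation had finished. The fix returns sse-starlette's EventSourceResponse instead, which frames and flushes every yielded chunk as its own Server-Sent Event, drops the manual framing in predict, emits a closing '[DONE]' event, and adds sse-starlette to requirements.txt.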
openai_api.py

@@ -11,9 +11,9 @@ from pydantic import BaseModel, Field
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
-from starlette.responses import StreamingResponse
 from typing import Any, Dict, List, Literal, Optional, Union
 from transformers import AutoTokenizer, AutoModel
+from sse_starlette.sse import ServerSentEvent, EventSourceResponse
 
 
 @asynccontextmanager
@@ -114,7 +114,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     if request.stream:
         generate = predict(query, history, request.model)
-        return StreamingResponse(generate, media_type="text/event-stream")
+        return EventSourceResponse(generate, media_type="text/event-stream")
 
     response, _ = model.chat(tokenizer, query, history=history)
     choice_data = ChatCompletionResponseChoice(
@@ -135,7 +135,7 @@ async def predict(query: str, history: List[List[str]], model_id: str):
         finish_reason=None
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
     current_length = 0
 
@@ -152,7 +152,8 @@ async def predict(query: str, history: List[List[str]], model_id: str):
             finish_reason=None
         )
         chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-        yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
+
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
@@ -160,7 +161,9 @@ async def predict(query: str, history: List[List[str]], model_id: str):
         finish_reason="stop"
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield '[DONE]'
 
 
+
 if __name__ == "__main__":
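For context, here is a minimal, self-contained sketch of the pattern the commit adopts (not code from this repo; the /demo route and number_stream generator are hypothetical names): with sse-starlette, each string yielded by the generator is framed as its own "data:" event and flushed immediately, so the generator no longer has to build raw SSE frames by hand.

import asyncio

from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()


async def number_stream():
    # sse-starlette wraps each yielded string as "data: <str>\n\n" and
    # flushes it right away, so the client sees chunks as they are produced.
    for i in range(5):
        yield str(i)
        await asyncio.sleep(0.5)
    yield '[DONE]'


@app.get("/demo")
async def demo():
    # Before this fix, the repo returned StreamingResponse and the generator
    # emitted hand-rolled "data: ...\n\n" frames, which could reach the
    # client all at once; EventSourceResponse owns the SSE framing instead.
    return EventSourceResponse(number_stream())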
requirements.txt

@@ -6,3 +6,4 @@ gradio
 mdtex2html
 sentencepiece
 accelerate
+sse-starlette
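A quick way to verify the fix from the client side, sketched under assumptions: the server from this repo is running at http://localhost:8000, the pre-1.0 openai Python SDK is installed, and the model name and prompt are placeholders.

# Hypothetical smoke test, not part of the repo. With the fix, the printed
# tokens should appear incrementally rather than in one burst at the end.
import openai  # assumes openai<1.0

openai.api_base = "http://localhost:8000/v1"
openai.api_key = "none"  # the demo server does not validate keys

for chunk in openai.ChatCompletion.create(
    model="chatglm2-6b",
    messages=[{"role": "user", "content": "你好"}],
    stream=True,
):
    # Each chunk mirrors the chat.completion.chunk objects built in predict().
    delta = chunk.choices[0].delta
    print(delta.get("content", ""), end="", flush=True)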