mirror of https://github.com/THUDM/ChatGLM2-6B
				
				
				
fix: openai_api stream API delivered the full generated text to the client all at once, only after the server finished generating
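
Per the commit message, the streaming endpoint only delivered its output to the client once generation had finished. The fix swaps Starlette's StreamingResponse for sse_starlette's EventSourceResponse, which frames every string the generator yields as its own SSE "data:" event, so the handwritten "data: ...\n\n" prefix is dropped and the stream ends with an OpenAI-style [DONE] marker. Below is a minimal standalone sketch of that pattern; it is not part of this commit, and the /demo-stream endpoint and payloads are made up.

# Standalone sketch: EventSourceResponse sends each yielded string as one
# SSE event as soon as it is produced.
import asyncio

from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()

async def fake_token_stream():
    for token in ["Hello", ", ", "world", "!"]:
        # each yield becomes one event on the wire: data: {"delta": "..."}
        yield '{{"delta": "{}"}}'.format(token)
        await asyncio.sleep(0.3)  # stand-in for model generation time
    yield "[DONE]"  # end-of-stream sentinel, mirroring the OpenAI format

@app.get("/demo-stream")
async def demo_stream():
    return EventSourceResponse(fake_token_stream())
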
parent 53f0106817
commit fcd2d7f4bb
openai_api.py
@@ -11,9 +11,9 @@ from pydantic import BaseModel, Field
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
-from starlette.responses import StreamingResponse
 from typing import Any, Dict, List, Literal, Optional, Union
 from transformers import AutoTokenizer, AutoModel
+from sse_starlette.sse import ServerSentEvent, EventSourceResponse
 
 
 @asynccontextmanager
@@ -114,7 +114,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     if request.stream:
         generate = predict(query, history, request.model)
-        return StreamingResponse(generate, media_type="text/event-stream")
+        return EventSourceResponse(generate, media_type="text/event-stream")
 
     response, _ = model.chat(tokenizer, query, history=history)
     choice_data = ChatCompletionResponseChoice(
@@ -135,7 +135,7 @@ async def predict(query: str, history: List[List[str]], model_id: str):
         finish_reason=None
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
     current_length = 0
 
@@ -152,7 +152,8 @@ async def predict(query: str, history: List[List[str]], model_id: str):
             finish_reason=None
         )
         chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-        yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+
 
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
@@ -160,7 +161,9 @@ async def predict(query: str, history: List[List[str]], model_id: str):
         finish_reason="stop"
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield '[DONE]'
+
 
 
 if __name__ == "__main__":
requirements.txt
@@ -6,3 +6,4 @@ gradio
 mdtex2html
 sentencepiece
 accelerate
+sse-starlette
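
For reference, one way to confirm the change from the client side is to point the pre-1.0 openai Python client at the local server and iterate over a streamed completion; the library stops reading when it receives the [DONE] sentinel that the last openai_api.py hunk now emits. A sketch follows; the host, port and model name are assumptions, not taken from this commit, so adjust them to your deployment.

# Client-side sketch (not part of this commit): tokens should now print as
# they are generated rather than all at once at the end.
import openai

openai.api_base = "http://localhost:8000/v1"  # assumed local deployment
openai.api_key = "none"  # placeholder for a local server without auth

for chunk in openai.ChatCompletion.create(
    model="chatglm2-6b",  # assumed model name
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
):
    delta = chunk.choices[0].delta
    print(delta.get("content", ""), end="", flush=True)
print()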