import json

import fastapi
import markdown
import uvicorn
from ctransformers import AutoModelForCausalLM
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

| config = {"max_seq_len": 4096} | |
| llm = AutoModelForCausalLM.from_pretrained('TheBloke/MPT-7B-Storywriter-GGML', | |
| model_file='mpt-7b-storywriter.ggmlv3.q4_0.bin', | |
| model_type='mpt') | |
app = fastapi.FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Landing page: render the project README as HTML (route path assumed).
@app.get("/")
async def index():
    with open("README.md", "r", encoding="utf-8") as readme_file:
        md_template_string = readme_file.read()
    html_content = markdown.markdown(md_template_string)
    return HTMLResponse(content=html_content, status_code=200)

class ChatCompletionRequest(BaseModel):
    prompt: str

# Demo page: a minimal EventSource client that appends streamed chunks
# to the page as they arrive (route path assumed).
@app.get("/demo")
async def demo():
    html_content = """
    <!DOCTYPE html>
    <html>
    <head>
        <style>
            #logs {
                background-color: black;
                color: white;
                height: 600px;
                overflow-x: hidden;
                overflow-y: auto;
                text-align: left;
                padding-left: 10px;
            }
        </style>
    </head>
    <body>
        <h1>StoryWriter Demo</h1>
        <div id="logs">
        </div>
        <script>
            var source = new EventSource("http://localhost:8000/stream");
            source.onmessage = function(event) {
                document.getElementById("logs").innerHTML += event.data + "<br>";
            };
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content, status_code=200)

# GET /stream: the endpoint the demo page's EventSource connects to.
@app.get("/stream")
async def stream(prompt: str = "Once upon a time there was a "):
    # llm() returns the full completion as a string; iterating it yields characters.
    completion = llm(prompt)

    async def server_sent_events(chat_chunks):
        # EventSource clients expect "data: ...\n\n" framed messages.
        yield f"data: {prompt}\n\n"
        for chat_chunk in chat_chunks:
            yield f"data: {chat_chunk}\n\n"
        yield "data: \n\n"

    return StreamingResponse(
        server_sent_events(completion), media_type="text/event-stream"
    )

# POST endpoint that streams the completion as server-sent events,
# with "[DONE]" marking the end of the stream (route path assumed).
@app.post("/chat")
async def chat(request: ChatCompletionRequest, response_mode=None):
    completion = llm(request.prompt)

    async def server_sent_events(chat_chunks):
        for chat_chunk in chat_chunks:
            yield dict(data=json.dumps(chat_chunk))
        yield dict(data="[DONE]")

    return EventSourceResponse(server_sent_events(completion))

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
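
A minimal client sketch for the streaming endpoint, assuming the server above is running on localhost:8000 and that `requests` is installed; the script is illustrative and not part of the app itself:

import requests

# Connect to the GET /stream endpoint and print chunks as they arrive.
url = "http://localhost:8000/stream"
params = {"prompt": "Once upon a time there was a "}

with requests.get(url, params=params, stream=True) as response:
    for line in response.iter_lines(decode_unicode=True):
        # SSE frames arrive as "data: <chunk>" lines separated by blank lines.
        if line and line.startswith("data: "):
            print(line[len("data: "):], end="", flush=True)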