import json
import os

# Cache downloaded weights locally; set before importing transformers so it takes effect.
os.environ['TRANSFORMERS_CACHE'] = ".cache"

from typing import List

import torch
from torch.cuda import get_device_properties
from transformers import AutoModel, AutoTokenizer
from fastapi import FastAPI, Request, status, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
from pyngrok import ngrok, conf
import uvicorn
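
# Paths and settings for the locally cached chatglm-6b-int4-slim checkpoint.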
bits = 4
kernel_path = "models/models--silver--chatglm-6b-int4-slim/quantization_kernels.so"
model_path = "./models/models--silver--chatglm-6b-int4-slim/snapshots/02e096b3805c579caf5741a6d8eddd5ba7a74e0d"
cache_dir = './models'
model_name = 'chatglm-6b-int4'
min_memory = 5.5  # minimum GPU memory (GiB) required to run on CUDA

tokenizer = None
model = None
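
# FastAPI app with permissive CORS so a browser front end on any origin can call the API.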
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
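
# Load the tokenizer and model, quantize, and pick GPU or CPU based on available VRAM.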
@app.on_event("startup")  # assumed wiring: initialize the model once when the server starts
def init():
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, cache_dir=cache_dir)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, cache_dir=cache_dir)
    if torch.cuda.is_available() and get_device_properties(0).total_memory / 1024 ** 3 > min_memory:
        # Enough VRAM: run the quantized model on the GPU in half precision.
        model = model.half().quantize(bits=bits).cuda()
        print("Using GPU")
    else:
        # Fall back to CPU inference in float32.
        model = model.float().quantize(bits=bits)
        if torch.cuda.is_available():
            print("Total Memory: ", get_device_properties(0).total_memory / 1024 ** 3)
        else:
            print("No GPU available")
        print("Using CPU")
    model = model.eval()
    ngrok_connect()
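
# OpenAI-style chat request schema.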
class Message(BaseModel):
    role: str
    content: str


class Body(BaseModel):
    messages: List[Message]
    model: str
    stream: bool
    max_tokens: int
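
# Simple health-check / hello endpoint.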
@app.get("/")
def read_root():
    return {"Hello": "World!"}
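
# Streaming chat endpoint: replays the conversation as ChatGLM history and streams tokens via SSE.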
@app.post("/v1/chat/completions")  # route path assumed (OpenAI-style); adjust to match your client
async def completions(body: Body, request: Request):
    # Only streaming requests for the expected model name are supported.
    if not body.stream or body.model != model_name:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "Not Implemented")
    question = body.messages[-1]
    if question.role == 'user':
        question = question.content
    else:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found")

    # Fold earlier messages into (user, assistant) pairs for ChatGLM's history format.
    user_question = ''
    history = []
    for message in body.messages:
        if message.role == 'user':
            user_question = message.content
        elif message.role == 'system' or message.role == 'assistant':
            assistant_answer = message.content
            history.append((user_question, assistant_answer))

    async def event_generator():
        # stream_chat yields (response_so_far, history) tuples; forward the text as SSE events.
        for response in model.stream_chat(tokenizer, question, history, max_length=max(2048, body.max_tokens)):
            if await request.is_disconnected():
                return
            yield json.dumps({"response": response[0]})
        yield "[DONE]"

    return EventSourceResponse(event_generator())
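
# Expose the local server through an ngrok tunnel (expects an "ngrok_token" environment
# variable and an ngrok binary at ./ngrok).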
def ngrok_connect():
    conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
    ngrok.set_auth_token(os.environ["ngrok_token"])
    http_tunnel = ngrok.connect(8000)
    print(http_tunnel.public_url)

if __name__ == "__main__":
    uvicorn.run("main:app", reload=True, app_dir=".")
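
# Example client call (a sketch, assuming the server runs on localhost:8000 and the route
# above is /v1/chat/completions; the endpoint only accepts stream=true for this model name):
#
#   curl -N http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "chatglm-6b-int4", "stream": true, "max_tokens": 2048,
#          "messages": [{"role": "user", "content": "Hello"}]}'
#
# The response is a Server-Sent Events stream of {"response": "..."} payloads ending with [DONE].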