Spaces:
Runtime error
Runtime error
| import json | |
| from threading import Lock | |
| from typing import ( | |
| Optional, | |
| Union, | |
| Iterator, | |
| Dict, | |
| Any, | |
| AsyncIterator, | |
| ) | |
| import anyio | |
| from anyio.streams.memory import MemoryObjectSendStream | |
| from fastapi import Depends, HTTPException, Request | |
| from fastapi.responses import JSONResponse | |
| from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer | |
| from loguru import logger | |
| from pydantic import BaseModel | |
| from starlette.concurrency import iterate_in_threadpool | |
| from api.config import SETTINGS | |
| from api.utils.compat import model_json, model_dump | |
| from api.utils.constants import ErrorCode | |
| from api.utils.protocol import ( | |
| ChatCompletionCreateParams, | |
| CompletionCreateParams, | |
| ErrorResponse, | |
| ) | |
# Process-wide locks shared by all requests served from this module.
# `get_event_publisher` polls `llama_outer_lock.locked()` to decide whether a
# running stream should be cut short when `SETTINGS.interrupt_requests` is set.
# NOTE(review): `llama_inner_lock` is not referenced in the visible part of
# this file; presumably it is used by code elsewhere -- confirm before removing.
llama_inner_lock = Lock()
llama_outer_lock = Lock()
async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer(auto_error=False)),
):
    """Validate the bearer token of the incoming request against ``SETTINGS.api_keys``.

    Args:
        auth: Credentials extracted by ``HTTPBearer(auto_error=False)``;
            may be ``None``, which is treated as an authentication failure
            when keys are configured.

    Returns:
        ``None`` when ``SETTINGS.api_keys`` is empty/unset (every caller is
        allowed), otherwise the accepted token string.

    Raises:
        HTTPException: status 401 with an ``invalid_api_key`` error payload
            when keys are configured and the token is missing or not listed.
    """
    if not SETTINGS.api_keys:
        # No keys configured: authentication is disabled.
        return None

    if auth is not None:
        token = auth.credentials
        if token in SETTINGS.api_keys:
            return token

    # Either no credentials were supplied or the token is not recognised.
    error_body = {
        "message": "",
        "type": "invalid_request_error",
        "param": None,
        "code": "invalid_api_key",
    }
    raise HTTPException(status_code=401, detail={"error": error_body})
def create_error_response(code: int, message: str) -> JSONResponse:
    """Wrap an error code and message in a JSON error response.

    Args:
        code: Application-level error code stored in the ``ErrorResponse`` body.
        message: Human-readable description stored in the ``ErrorResponse`` body.

    Returns:
        A ``JSONResponse`` whose body is the dumped ``ErrorResponse``. The HTTP
        status is always 500, independent of ``code``.
    """
    error = ErrorResponse(message=message, code=code)
    payload = model_dump(error)
    return JSONResponse(payload, status_code=500)
async def handle_request(
    request: Union[CompletionCreateParams, ChatCompletionCreateParams],
    stop: Optional[Dict[str, Any]] = None,
    chat: bool = True,
) -> Union[Union[CompletionCreateParams, ChatCompletionCreateParams], JSONResponse]:
    """Validate ``request`` and normalise its stop and sampling settings in place.

    Args:
        request: Completion or chat-completion parameters. The object is
            mutated: ``stop``, ``stop_token_ids`` and ``top_p`` are rewritten.
        stop: Optional defaults to merge into the request. The keys
            ``"strings"`` and ``"token_ids"`` are read; missing or ``None``
            values are treated as empty.
        chat: Whether this is a chat request. Only chat requests receive the
            extra ``"Observation:"`` stop string for qwen models with functions.

    Returns:
        The error ``JSONResponse`` produced by ``check_requests`` when
        validation fails, otherwise the same ``request`` object after
        normalisation.
    """
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret

    # Defaults supplied by the caller (e.g. per-model stop words).
    default_stop_strings = []
    default_stop_token_ids = []
    if stop is not None:
        default_stop_token_ids = list(stop.get("token_ids") or [])
        default_stop_strings = list(stop.get("strings") or [])

    # The request may carry no stop value, a single string, or a list.
    if not request.stop:
        stop_strings = []
    elif isinstance(request.stop, str):
        stop_strings = [request.stop]
    else:
        stop_strings = list(request.stop)

    if chat and "qwen" in SETTINGS.model_name.lower() and request.functions:
        stop_strings.append("Observation:")

    # dict.fromkeys de-duplicates while keeping a stable order (defaults
    # first), unlike set(), whose string ordering changes between runs.
    request.stop = list(dict.fromkeys(default_stop_strings + stop_strings))
    request.stop_token_ids = list(
        dict.fromkeys(default_stop_token_ids + list(request.stop_token_ids or []))
    )

    # check_requests accepts None for both sampling fields, so they must not
    # be compared or clamped unconditionally here.
    if request.top_p is not None:
        request.top_p = max(request.top_p, 1e-5)
    if request.temperature is not None and request.temperature <= 1e-5:
        # Near-zero temperature: top_p is reset to 1.0 (presumably because
        # sampling is effectively greedy at that point -- confirm with engine).
        request.top_p = 1.0

    return request
def check_requests(request: Union[CompletionCreateParams, ChatCompletionCreateParams]) -> Optional[JSONResponse]:
    """Range-check the sampling parameters of ``request``.

    Checks are applied in order to ``max_tokens``, ``n``, ``temperature``,
    ``top_p`` and ``stop``; the first violation wins. A field set to ``None``
    is skipped.

    Args:
        request: Completion or chat-completion parameters to validate.

    Returns:
        ``None`` when every check passes, otherwise the ``JSONResponse`` built
        by ``create_error_response`` with ``ErrorCode.PARAM_OUT_OF_RANGE``.
    """
    if request.max_tokens is not None and request.max_tokens <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
        )
    if request.n is not None and request.n <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.n} is less than the minimum of 1 - 'n'",
        )
    if request.temperature is not None and request.temperature < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is less than the minimum of 0 - 'temperature'",
        )
    if request.temperature is not None and request.temperature > 2:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
        )
    if request.top_p is not None and request.top_p < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is less than the minimum of 0 - 'top_p'",
        )
    if request.top_p is not None and request.top_p > 1:
        # The message previously named 'temperature' here (copy-paste slip).
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
        )
    # `stop` may be absent, a single string, or a list; anything else is rejected.
    if request.stop is not None and not isinstance(request.stop, (str, list)):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.stop} is not valid under any of the given schemas - 'stop'",
        )
    return None
async def get_event_publisher(
    request: Request,
    inner_send_chan: MemoryObjectSendStream,
    iterator: Union[Iterator, AsyncIterator],
):
    """Pump chunks from ``iterator`` into ``inner_send_chan`` as ``{"data": ...}`` events.

    For the ``vllm`` and ``tgi`` engines ``iterator`` is consumed directly with
    ``async for`` and every chunk is serialised with ``model_json``. For any
    other engine it is consumed through ``iterate_in_threadpool``; pydantic
    models go through ``model_json``, dicts through ``json.dumps``, and
    anything else is forwarded unchanged.

    After each chunk the client connection is checked; on disconnect the
    anyio cancellation exception is raised. In the non-vllm/tgi path the stream
    is also terminated (after sending ``"[DONE]"``) when
    ``SETTINGS.interrupt_requests`` is truthy and ``llama_outer_lock`` is held.

    Args:
        request: The incoming request, used for disconnect detection and logging.
        inner_send_chan: Send side of the memory object stream; it is entered
            as an async context manager for the duration of the call.
        iterator: Source of chunks (sync or async depending on the engine).

    Raises:
        The class returned by ``anyio.get_cancelled_exc_class()``; it is logged
        and re-raised.
    """
    async with inner_send_chan:
        try:
            if SETTINGS.engine in ["vllm", "tgi"]:
                async for item in iterator:
                    await inner_send_chan.send({"data": model_json(item)})
                    if await request.is_disconnected():
                        raise anyio.get_cancelled_exc_class()()
            else:
                async for item in iterate_in_threadpool(iterator):
                    if isinstance(item, BaseModel):
                        payload = model_json(item)
                    elif isinstance(item, dict):
                        payload = json.dumps(item, ensure_ascii=False)
                    else:
                        payload = item

                    await inner_send_chan.send({"data": payload})

                    if await request.is_disconnected():
                        raise anyio.get_cancelled_exc_class()()

                    if SETTINGS.interrupt_requests and llama_outer_lock.locked():
                        # Close the stream cleanly before aborting.
                        await inner_send_chan.send({"data": "[DONE]"})
                        raise anyio.get_cancelled_exc_class()()

            # NOTE(review): nesting reconstructed from flattened source; the
            # terminating marker is assumed to follow both engine branches.
            await inner_send_chan.send({"data": "[DONE]"})
        except anyio.get_cancelled_exc_class() as cancelled:
            logger.info("disconnected")
            # Shielded scope so the log call is not itself cancelled.
            with anyio.move_on_after(1, shield=True):
                logger.info(f"Disconnected from client (via refresh/close) {request.client}")
                raise cancelled