import logging
import os
from typing import Annotated, Optional, Tuple

from fastapi import (APIRouter, BackgroundTasks, Depends, HTTPException,
                     Response, UploadFile)
from sqlalchemy.orm import Session

import common.dependencies as DI
from common.configuration import Configuration, Query, SummaryChunks
from common.constants import PROMPT
from common.exceptions import LLMResponseException
from components.datasets.dispatcher import Dispatcher
from components.dbo.models.log import Log
from components.llm.common import ChatRequest, LlmParams, LlmPredictParams, Message
from components.llm.deepinfra_api import DeepInfraApi
from components.llm.llm_api import LlmApi
from components.llm.prompts import SYSTEM_PROMPT
from components.llm.utils import (append_llm_response_to_history,
                                  convert_to_openai_format)
from components.nmd.aggregate_answers import preprocessed_chunks
from components.nmd.llm_chunk_search import LLMChunkSearch
from components.services.dataset import DatasetService
from components.services.llm_config import LLMConfigService
from components.services.llm_prompt import LlmPromptService
from schemas.dataset import (Dataset, DatasetExpanded, DatasetProcessing,
                             SortQuery, SortQueryList)

router = APIRouter(prefix='/llm')
logger = logging.getLogger(__name__)

conf = DI.get_config()
llm_params = LlmParams(
    url=conf.llm_config.base_url,
    model=conf.llm_config.model,
    tokenizer="unsloth/Llama-3.3-70B-Instruct",
    type="deepinfra",
    default=True,
    predict_params=LlmPredictParams(
        temperature=0.15, top_p=0.95, min_p=0.05, seed=42,
        repetition_penalty=1.2, presence_penalty=1.1, n_predict=2000,
    ),
    api_key=os.environ.get(conf.llm_config.api_key_env),
    context_length=128000,
)

# TODO: move this into DI instead of constructing it at module level
llm_api = DeepInfraApi(params=llm_params)


@router.post('/chunks')  # NOTE: decorator assumed; the path follows the log message below
def get_chunks(query: Query,
               dispatcher: Annotated[Dispatcher, Depends(DI.get_dispatcher)]) -> SummaryChunks:
    logger.info(f"Handling POST request to /chunks with query: {query.query}")
    try:
        result = dispatcher.search_answer(query)
        logger.info("Successfully retrieved chunks")
        return result
    except Exception as e:
        logger.error(f"Error retrieving chunks: {str(e)}")
        raise


def llm_answer(query: str, answer_chunks: SummaryChunks,
               config: Configuration) -> Tuple[str, str, str, int]:
    """
    Finds the correct answer with the help of the LLM.

    Args:
        query: The user query.
        answer_chunks: Results from the vector and Elasticsearch searches.
        config: Application configuration.

    Returns:
        The original chunks from the searches and the chunk selected by the model,
        together with the prompt that was used.
    """
    prompt = PROMPT
    llm_search = LLMChunkSearch(config.llm_config, PROMPT, logger)
    return llm_search.llm_chunk_search(query, answer_chunks, prompt)


@router.post('/answer_llm')  # NOTE: decorator assumed; the path follows the log message below
def get_llm_answer(query: Query,
                   chunks: SummaryChunks,
                   db: Annotated[Session, Depends(DI.get_db)],
                   config: Annotated[Configuration, Depends(DI.get_config)]):
    logger.info(f"Handling POST request to /answer_llm with query: {query.query}")
    try:
        text_chunks, answer_llm, llm_prompt, _ = llm_answer(query.query, chunks, config)
        if not answer_llm:
            logger.error("LLM returned empty response")
            raise LLMResponseException()

        log_entry = Log(
            llmPrompt=llm_prompt,
            llmResponse=answer_llm,
            userRequest=query.query,
            query_type=chunks.query_type,
            userName=query.userName,
        )

        # db is expected to be a session factory here, since it is called to open a session
        with db() as session:
            session.add(log_entry)
            session.commit()
            session.refresh(log_entry)

        logger.info(f"Successfully processed LLM request, log_id: {log_entry.id}")
        return {
            "answer_llm": answer_llm,
            "log_id": log_entry.id,
        }
    except Exception as e:
        logger.error(f"Error processing LLM request: {str(e)}")
        raise


@router.post('/chat')  # NOTE: decorator and path assumed; they do not appear in the source
async def chat(request: ChatRequest,
               config: Annotated[Configuration, Depends(DI.get_config)],
               llm_api: Annotated[DeepInfraApi, Depends(DI.get_llm_service)],
               prompt_service: Annotated[LlmPromptService, Depends(DI.get_llm_prompt_service)],
               llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
               dispatcher: Annotated[Dispatcher, Depends(DI.get_dispatcher)]):
    try:
        p = llm_config_service.get_default()
        system_prompt = prompt_service.get_default()
        predict_params = LlmPredictParams(
            temperature=p.temperature, top_p=p.top_p, min_p=p.min_p, seed=p.seed,
            frequency_penalty=p.frequency_penalty, presence_penalty=p.presence_penalty,
            n_predict=p.n_predict, stop=[],
        )

        # TODO: move these helpers out of the endpoint
        def get_last_user_message(chat_request: ChatRequest) -> Optional[Message]:
            # Most recent user message that does not yet carry search results
            return next(
                (
                    msg for msg in reversed(chat_request.history)
                    if msg.role == "user" and not msg.searchResults
                ),
                None,
            )

        def insert_search_results_to_message(chat_request: ChatRequest, new_content: str) -> bool:
            # Rewrite the content of the most recent user message without search results
            for msg in reversed(chat_request.history):
                if msg.role == "user" and not msg.searchResults:
                    msg.content = new_content
                    return True
            return False

        last_query = get_last_user_message(request)
        search_result = None
        if last_query:
            search_result = dispatcher.search_answer(
                Query(query=last_query.content, query_abbreviation=last_query.content)
            )
            text_chunks = preprocessed_chunks(search_result, None, logger)
            new_message = f'{last_query.content}\n<search-results>\n{text_chunks}\n</search-results>'
            insert_search_results_to_message(request, new_message)

        response = await llm_api.predict_chat_stream(request, system_prompt.text, predict_params)
        result = append_llm_response_to_history(request, response)
        return result
    except Exception as e:
        logger.error(f"Error processing LLM request: {str(e)}", exc_info=True)
        return {"error": str(e)}