import json
import logging
import os
from typing import Annotated, AsyncGenerator, List, Optional

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse

import common.dependencies as DI
from common import auth
from common.configuration import Configuration
from components.llm.common import (ChatRequest, LlmParams, LlmPredictParams,
                                   Message)
from components.llm.deepinfra_api import DeepInfraApi
from components.llm.utils import append_llm_response_to_history
from components.services.dataset import DatasetService
from components.services.dialogue import DialogueService
from components.services.entity import EntityService
from components.services.llm_config import LLMConfigService
from components.services.llm_prompt import LlmPromptService
from components.services.log import LogService
from schemas.log import LogCreateSchema

router = APIRouter(prefix='/llm', tags=['LLM chat'])
logger = logging.getLogger(__name__)

conf = DI.get_config()
llm_params = LlmParams(
    **{
        "url": conf.llm_config.base_url,
        "model": conf.llm_config.model,
        "tokenizer": "unsloth/Llama-3.3-70B-Instruct",
        "type": "deepinfra",
        "default": True,
        "predict_params": LlmPredictParams(
            temperature=0.15,
            top_p=0.95,
            min_p=0.05,
            seed=42,
            repetition_penalty=1.2,
            presence_penalty=1.1,
            n_predict=2000,
        ),
        "api_key": os.environ.get(conf.llm_config.api_key_env),
        "context_length": 128000,
    }
)
# TODO: move this into DI
llm_api = DeepInfraApi(params=llm_params)
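# Note: this module-level DeepInfraApi mirrors the TODO above. The endpoints below receive
# their client via Depends(DI.get_llm_service), so this instance presumably serves only as a
# fallback until construction is moved into DI.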


# TODO: move out of this module
def get_last_user_message(chat_request: ChatRequest) -> Optional[Message]:
    """Return the most recent user message from the history, if any."""
    return next(
        (
            msg
            for msg in reversed(chat_request.history)
            if msg.role == "user"
        ),
        None,
    )


def insert_search_results_to_message(
    chat_request: ChatRequest, new_content: str
) -> bool:
    """Replace the content of the most recent user message that has no search results yet."""
    for msg in reversed(chat_request.history):
        if msg.role == "user" and not msg.searchResults:
            msg.content = new_content
            return True
    return False


def try_insert_search_results(
    chat_request: ChatRequest, search_results: str
) -> bool:
    """Attach search results to the most recent user message and reset its entities."""
    for msg in reversed(chat_request.history):
        if msg.role == "user":
            msg.searchResults = search_results
            msg.searchEntities = []
            return True
    return False


def try_insert_reasoning(
    chat_request: ChatRequest, reasoning: str
) -> None:
    """Attach reasoning text to the most recent user message."""
    for msg in reversed(chat_request.history):
        if msg.role == "user":
            msg.reasoning = reasoning
            # Stop at the most recent user message, matching try_insert_search_results
            return


def collapse_history_to_first_message(chat_history: List[Message]) -> List[Message]:
    """
    Collapse the chat history into a single user message and return it as a
    one-element list.

    Format:
    <history>
        <user>
            message text
        </user>
        <reasoning>
            reasoning text
        </reasoning>
        <search-results>
            search-results text
        </search-results>
        <assistant>
            assistant reply text
        </assistant>
    </history>
    <last-request>
        <reasoning>
            reasoning text
        </reasoning>
        <search-results>
            search-results text
        </search-results>
        user:
            text of the last request
    </last-request>
    assistant:
    """
    if not chat_history:
        return []

    last_user_message = chat_history[-1]
    if chat_history[-1].role != "user":
        logger.warning("Last message is not a user message")

    # Assemble the history into a single string
    collapsed_content = []
    collapsed_content.append("<INPUT><history>\n")
    for msg in chat_history[:-1]:
        if msg.content.strip():
            tabulated_content = msg.content.strip().replace("\n", "\n\t\t")
            collapsed_content.append(
                f"\t<{msg.role.strip()}>\n\t\t{tabulated_content}\n\t</{msg.role.strip()}>\n"
            )
        if msg.role == "user":
            # Guard against missing reasoning/searchResults; both are currently only
            # used by the commented-out blocks below.
            tabulated_reasoning = (msg.reasoning or "").strip().replace("\n", "\n\t\t")
            tabulated_search_results = (msg.searchResults or "").strip().replace("\n", "\n\t\t")
            # collapsed_content.append(f"\t<reasoning>\n\t\t{tabulated_reasoning}\n\t</reasoning>\n")
            # collapsed_content.append(f"\t<search-results>\n\t\t{tabulated_search_results}\n\t</search-results>\n")
    collapsed_content.append("</history>\n")
    collapsed_content.append("<last-request>\n")
    if last_user_message.content.strip():
        tabulated_content = last_user_message.content.strip().replace("\n", "\n\t\t")
        tabulated_reasoning = (last_user_message.reasoning or "").strip().replace("\n", "\n\t\t")
        tabulated_search_results = (last_user_message.searchResults or "").strip().replace("\n", "\n\t\t")
        # collapsed_content.append(f"\t<reasoning>\n\t\t{tabulated_reasoning}\n\t</reasoning>\n")
        collapsed_content.append(f"\t<search-results>\n\t\t{tabulated_search_results}\n\t</search-results>\n")
        collapsed_content.append(f"\t<user>\n\t\t{tabulated_content}\n</user>\n")
    collapsed_content.append("</last-request>\n")
    collapsed_content.append("</INPUT><OUTPUT>\n")

    new_content = "".join(collapsed_content)
    new_message = Message(
        role='user',
        content=new_content,
        searchResults=''
    )
    return [new_message]
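
# Illustrative example (hypothetical three-message history ending with a user turn):
# collapsing [user "Hi", assistant "Hello!", user "What is in the dataset?"] yields a single
# user Message whose content looks roughly like:
#
#   <INPUT><history>
#       <user>
#           Hi
#       </user>
#       <assistant>
#           Hello!
#       </assistant>
#   </history>
#   <last-request>
#       <search-results>
#       </search-results>
#       <user>
#           What is in the dataset?
#       </user>
#   </last-request>
#   </INPUT><OUTPUT>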


async def sse_generator(request: ChatRequest, llm_api: DeepInfraApi, system_prompt: str,
                        predict_params: LlmPredictParams,
                        dataset_service: DatasetService,
                        entity_service: EntityService,
                        dialogue_service: DialogueService,
                        log_service: LogService,
                        current_user: auth.User) -> AsyncGenerator[str, None]:
    """
    Generator that streams the LLM response to the client over SSE.
    """
    # Create a single log entry that follows the request through the whole pipeline
    log = LogCreateSchema(user_name=current_user.username, chat_id=request.chat_id)

    try:
        old_history = request.history

        # Store the last user message in the log as the original request
        last_message = get_last_user_message(request)
        log.user_request = last_message.content if last_message is not None else None

        new_history = [Message(
            role=msg.role,
            content=msg.content,
            reasoning=msg.reasoning,
            searchResults='',  # msg.searchResults[:10000] + "..." if msg.searchResults else '',
            searchEntities=[],
        ) for msg in old_history]
        request.history = new_history

        qe_result = await dialogue_service.get_qe_result(request.history)
        # Record the QE result in the log
        log.qe_result = qe_result.model_dump_json()
        try_insert_reasoning(request, qe_result.debug_message)

        # qe_debug_event = {
        #     "event": "debug",
        #     "data": {
        #         "text": qe_result.debug_message
        #     }
        # }
        # yield f"data: {json.dumps(qe_debug_event, ensure_ascii=False)}\n\n"

        qe_event = {
            "event": "reasoning",
            "data": {
                "text": qe_result.debug_message
            }
        }
        yield f"data: {json.dumps(qe_event, ensure_ascii=False)}\n\n"
    except Exception as e:
        log.error = "Error in QE block: " + str(e)
        log_service.create(log)
        logger.error(f"Error in SSE chat stream while dialogue_service.get_qe_result: {str(e)}", stack_info=True)
        yield f"data: {json.dumps({'event': 'error', 'data': str(e)}, ensure_ascii=False)}\n\n"
        # Fall back to deriving the QE result directly from the chat history
        qe_result = dialogue_service.get_qe_result_from_chat(request.history)

    try:
        if qe_result.use_search and qe_result.search_query is not None:
            dataset = dataset_service.get_current_dataset()
            if dataset is None:
                raise HTTPException(status_code=400, detail="Dataset not found")
            _, chunk_ids, scores = entity_service.search_similar(
                qe_result.search_query,
                dataset.id,
                [],
            )
            text_chunks = await entity_service.build_text_async(chunk_ids, dataset.id, scores)

            # Record the search results in the log
            log.search_result = text_chunks

            search_results_event = {
                "event": "search_results",
                "data": {
                    "text": text_chunks,
                    "ids": chunk_ids
                }
            }
            yield f"data: {json.dumps(search_results_event, ensure_ascii=False)}\n\n"

            # new_message = f'<search-results>\n{text_chunks}\n</search-results>\n{last_query.content}'
            try_insert_search_results(request, text_chunks)
    except Exception as e:
        log.error = "Error in vector search block: " + str(e)
        log_service.create(log)
        logger.error(f"Error in SSE chat stream while searching: {str(e)}", stack_info=True)
        yield f"data: {json.dumps({'event': 'error', 'data': str(e)}, ensure_ascii=False)}\n\n"

    try:
        # Collapse the history into a single first message
        collapsed_request = ChatRequest(
            history=collapse_history_to_first_message(request.history),
            chat_id=request.chat_id,
        )
        log.llm_result = ''

        # Stream the response tokens
        async for token in llm_api.get_predict_chat_generator(collapsed_request, system_prompt, predict_params):
            token_event = {"event": "token", "data": token}
            log.llm_result += token
            yield f"data: {json.dumps(token_event, ensure_ascii=False)}\n\n"

        # Final event
        yield "data: {\"event\": \"done\"}\n\n"
    except Exception as e:
        log.error = "Error in llm inference block: " + str(e)
        logger.error(f"Error in SSE chat stream while generating response: {str(e)}", stack_info=True)
        yield f"data: {json.dumps({'event': 'error', 'data': str(e)}, ensure_ascii=False)}\n\n"
    finally:
        log_service.create(log)
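
# The SSE stream produced by sse_generator emits `data: {...}` events in this order:
#   "reasoning"       - QE debug message
#   "search_results"  - retrieved text chunks and their ids (only when search is used)
#   "token"           - streamed LLM tokens, repeated
#   "done"            - end of stream
# An "error" event may be emitted from any block; processing continues where a fallback exists.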


async def chat_stream(
    request: ChatRequest,
    config: Annotated[Configuration, Depends(DI.get_config)],
    llm_api: Annotated[DeepInfraApi, Depends(DI.get_llm_service)],
    prompt_service: Annotated[LlmPromptService, Depends(DI.get_llm_prompt_service)],
    llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
    entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
    dataset_service: Annotated[DatasetService, Depends(DI.get_dataset_service)],
    dialogue_service: Annotated[DialogueService, Depends(DI.get_dialogue_service)],
    log_service: Annotated[LogService, Depends(DI.get_log_service)],
    current_user: Annotated[auth.User, Depends(auth.get_current_user)],
):
    """Stream a chat completion to the client as server-sent events."""
    try:
        p = llm_config_service.get_default()
        system_prompt = prompt_service.get_default()
        predict_params = LlmPredictParams(
            temperature=p.temperature,
            top_p=p.top_p,
            min_p=p.min_p,
            seed=p.seed,
            frequency_penalty=p.frequency_penalty,
            presence_penalty=p.presence_penalty,
            n_predict=p.n_predict,
            stop=[],
        )
        headers = {
            "Content-Type": "text/event-stream",
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
        }
        return StreamingResponse(
            sse_generator(request, llm_api, system_prompt.text, predict_params, dataset_service,
                          entity_service, dialogue_service, log_service, current_user),
            media_type="text/event-stream",
            headers=headers,
        )
    except Exception as e:
        logger.error(f"Error in SSE chat stream: {str(e)}", stack_info=True)
        raise HTTPException(status_code=500, detail=str(e))
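
# Minimal client sketch for consuming the stream (illustrative only: the exact route path and
# the `chat_request` payload are assumptions, since the route decorators are not shown here):
#
#   import httpx, json
#   async with httpx.AsyncClient() as client:
#       async with client.stream("POST", "http://localhost:8000/llm/chat/stream", json=chat_request) as resp:
#           async for line in resp.aiter_lines():
#               if line.startswith("data: "):
#                   event = json.loads(line[len("data: "):])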


async def chat(
    request: ChatRequest,
    config: Annotated[Configuration, Depends(DI.get_config)],
    llm_api: Annotated[DeepInfraApi, Depends(DI.get_llm_service)],
    prompt_service: Annotated[LlmPromptService, Depends(DI.get_llm_prompt_service)],
    llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
    entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
    dataset_service: Annotated[DatasetService, Depends(DI.get_dataset_service)],
    dialogue_service: Annotated[DialogueService, Depends(DI.get_dialogue_service)],
):
    """Non-streaming chat: runs QE, optional vector search, then a single LLM call."""
    try:
        p = llm_config_service.get_default()
        system_prompt = prompt_service.get_default()
        predict_params = LlmPredictParams(
            temperature=p.temperature,
            top_p=p.top_p,
            min_p=p.min_p,
            seed=p.seed,
            frequency_penalty=p.frequency_penalty,
            presence_penalty=p.presence_penalty,
            n_predict=p.n_predict,
            stop=[],
        )

        try:
            qe_result = await dialogue_service.get_qe_result(request.history)
        except Exception as e:
            logger.error(f"Error in chat while dialogue_service.get_qe_result: {str(e)}", stack_info=True)
            qe_result = dialogue_service.get_qe_result_from_chat(request.history)

        last_message = get_last_user_message(request)
        logger.info(f"qe_result: {qe_result}")

        if qe_result.use_search and qe_result.search_query is not None:
            dataset = dataset_service.get_current_dataset()
            if dataset is None:
                raise HTTPException(status_code=400, detail="Dataset not found")
            logger.info(f"qe_result.search_query: {qe_result.search_query}")

            previous_entities = [msg.searchEntities for msg in request.history]
            previous_entities, chunk_ids, scores = entity_service.search_similar(
                qe_result.search_query, dataset.id, previous_entities
            )
            chunks = entity_service.chunk_repository.get_entities_by_ids(chunk_ids)
            logger.info(f"chunk_ids: {chunk_ids[:3]}...{chunk_ids[-3:]}")
            logger.info(f"scores: {scores[:3]}...{scores[-3:]}")
            text_chunks = await entity_service.build_text_async(chunk_ids, dataset.id, scores)
            logger.info(f"text_chunks: {text_chunks[:3]}...{text_chunks[-3:]}")

            new_message = f'{last_message.content}\n<search-results>\n{text_chunks}\n</search-results>'
            insert_search_results_to_message(request, new_message)

        logger.info(f"request: {request}")
        response = await llm_api.predict_chat_stream(
            request, system_prompt.text, predict_params
        )
        result = append_llm_response_to_history(request, response)
        return result
    except Exception as e:
        logger.error(
            f"Error processing LLM request: {str(e)}", stack_info=True, stacklevel=10
        )
        return {"error": str(e)}