Spaces:
Runtime error
Runtime error
| import json | |
| import os | |
| import requests | |
| from typing import Optional, List, Any | |
| from pydantic import BaseModel, Field | |
| class LlmPredictParams(BaseModel): | |
| """ | |
| Параметры для предсказания LLM. | |
| """ | |
| system_prompt: Optional[str] = Field(None, description="Системный промпт.") | |
| user_prompt: Optional[str] = Field(None, description="Шаблон промпта для передачи от роли user.") | |
| n_predict: Optional[int] = None | |
| temperature: Optional[float] = None | |
| top_k: Optional[int] = None | |
| top_p: Optional[float] = None | |
| min_p: Optional[float] = None | |
| seed: Optional[int] = None | |
| repeat_penalty: Optional[float] = None | |
| repeat_last_n: Optional[int] = None | |
| retry_if_text_not_present: Optional[str] = None | |
| retry_count: Optional[int] = None | |
| presence_penalty: Optional[float] = None | |
| frequency_penalty: Optional[float] = None | |
| n_keep: Optional[int] = None | |
| cache_prompt: Optional[bool] = None | |
| stop: Optional[List[str]] = None | |
| class LlmParams(BaseModel): | |
| """ | |
| Основные параметры для LLM. | |
| """ | |
| url: str | |
| type: Optional[str] = None | |
| default: Optional[bool] = None | |
| template: Optional[str] = None | |
| predict_params: Optional[LlmPredictParams] = None | |
| class LlmApi: | |
| """ | |
| Класс для работы с API vllm. | |
| """ | |
| params: LlmParams = None | |
| def __init__(self, params: LlmParams): | |
| self.params = params | |
| def get_models(self) -> list[str]: | |
| """ | |
| Выполняет GET-запрос к API для получения списка доступных моделей. | |
| Возвращает: | |
| list[str]: Список идентификаторов моделей. | |
| Если произошла ошибка или данные недоступны, возвращается пустой список. | |
| Исключения: | |
| Все ошибки HTTP-запросов логируются в консоль, но не выбрасываются дальше. | |
| """ | |
| try: | |
| response = requests.get(f"{self.params.url}/v1/models", headers={"Content-Type": "application/json"}) | |
| if response.status_code == 200: | |
| json_data = response.json() | |
| result = [item['id'] for item in json_data.get('data', [])] | |
| return result | |
| except requests.RequestException as error: | |
| print('OpenAiService.getModels error:') | |
| print(error) | |
| return [] | |
| def create_messages(self, prompt: str) -> list[dict]: | |
| """ | |
| Создает сообщения для LLM на основе переданного промпта и системного промпта (если он задан). | |
| Args: | |
| prompt (str): Пользовательский промпт. | |
| Returns: | |
| list[dict]: Список сообщений с ролями и содержимым. | |
| """ | |
| actual_prompt = self.apply_llm_template_to_prompt(prompt) | |
| messages = [] | |
| if self.params.predict_params and self.params.predict_params.system_prompt: | |
| messages.append({"role": "system", "content": self.params.predict_params.system_prompt}) | |
| messages.append({"role": "user", "content": actual_prompt}) | |
| return messages | |
| def apply_llm_template_to_prompt(self, prompt: str) -> str: | |
| """ | |
| Применяет шаблон LLM к переданному промпту, если он задан. | |
| Args: | |
| prompt (str): Пользовательский промпт. | |
| Returns: | |
| str: Промпт с примененным шаблоном (или оригинальный, если шаблон отсутствует). | |
| """ | |
| actual_prompt = prompt | |
| if self.params.template is not None: | |
| actual_prompt = self.params.template.replace("{{PROMPT}}", actual_prompt) | |
| return actual_prompt | |
| def tokenize(self, prompt: str) -> Optional[dict]: | |
| """ | |
| Выполняет токенизацию переданного промпта. | |
| Args: | |
| prompt (str): Промпт для токенизации. | |
| Returns: | |
| Optional[dict]: Словарь с токенами и максимальной длиной модели, если запрос успешен. | |
| Если запрос неуспешен, возвращает None. | |
| """ | |
| model = self.get_models()[0] if self.get_models() else None | |
| if not model: | |
| print("No models available for tokenization.") | |
| return None | |
| actual_prompt = self.apply_llm_template_to_prompt(prompt) | |
| request_data = { | |
| "model": model, | |
| "prompt": actual_prompt, | |
| "add_special_tokens": False, | |
| } | |
| try: | |
| response = requests.post( | |
| f"{self.params.url}/tokenize", | |
| json=request_data, | |
| headers={"Content-Type": "application/json"}, | |
| ) | |
| if response.ok: | |
| data = response.json() | |
| if "tokens" in data: | |
| return {"tokens": data["tokens"], "maxLength": data.get("max_model_len")} | |
| elif response.status_code == 404: | |
| print("Tokenization endpoint not found (404).") | |
| else: | |
| print(f"Failed to tokenize: {response.status_code}") | |
| except requests.RequestException as e: | |
| print(f"Request failed: {e}") | |
| return None | |
| def detokenize(self, tokens: List[int]) -> Optional[str]: | |
| """ | |
| Выполняет детокенизацию переданных токенов. | |
| Args: | |
| tokens (List[int]): Список токенов для детокенизации. | |
| Returns: | |
| Optional[str]: Строка, полученная в результате детокенизации, если запрос успешен. | |
| Если запрос неуспешен, возвращает None. | |
| """ | |
| model = self.get_models()[0] if self.get_models() else None | |
| if not model: | |
| print("No models available for detokenization.") | |
| return None | |
| request_data = {"model": model, "tokens": tokens or []} | |
| try: | |
| response = requests.post( | |
| f"{self.params.url}/detokenize", | |
| json=request_data, | |
| headers={"Content-Type": "application/json"}, | |
| ) | |
| if response.ok: | |
| data = response.json() | |
| if "prompt" in data: | |
| return data["prompt"].strip() | |
| elif response.status_code == 404: | |
| print("Detokenization endpoint not found (404).") | |
| else: | |
| print(f"Failed to detokenize: {response.status_code}") | |
| except requests.RequestException as e: | |
| print(f"Request failed: {e}") | |
| return None | |
| def create_request(self, prompt: str) -> dict: | |
| """ | |
| Создает запрос для предсказания на основе параметров LLM. | |
| Args: | |
| prompt (str): Промпт для запроса. | |
| Returns: | |
| dict: Словарь с параметрами для выполнения запроса. | |
| """ | |
| llm_params = self.params | |
| models = self.get_models() | |
| if not models: | |
| raise ValueError("No models available to create a request.") | |
| model = models[0] | |
| request = { | |
| "stream": True, | |
| "model": model, | |
| } | |
| predict_params = llm_params.predict_params | |
| if predict_params: | |
| if predict_params.stop: | |
| # Фильтруем пустые строки в stop | |
| non_empty_stop = list(filter(lambda o: o != "", predict_params.stop)) | |
| if non_empty_stop: | |
| request["stop"] = non_empty_stop | |
| if predict_params.n_predict is not None: | |
| request["max_tokens"] = int(predict_params.n_predict or 0) | |
| request["temperature"] = float(predict_params.temperature or 0) | |
| if predict_params.top_k is not None: | |
| request["top_k"] = int(predict_params.top_k) | |
| if predict_params.top_p is not None: | |
| request["top_p"] = float(predict_params.top_p) | |
| if predict_params.min_p is not None: | |
| request["min_p"] = float(predict_params.min_p) | |
| if predict_params.seed is not None: | |
| request["seed"] = int(predict_params.seed) | |
| if predict_params.n_keep is not None: | |
| request["n_keep"] = int(predict_params.n_keep) | |
| if predict_params.cache_prompt is not None: | |
| request["cache_prompt"] = bool(predict_params.cache_prompt) | |
| if predict_params.repeat_penalty is not None: | |
| request["repetition_penalty"] = float(predict_params.repeat_penalty) | |
| if predict_params.repeat_last_n is not None: | |
| request["repeat_last_n"] = int(predict_params.repeat_last_n) | |
| if predict_params.presence_penalty is not None: | |
| request["presence_penalty"] = float(predict_params.presence_penalty) | |
| if predict_params.frequency_penalty is not None: | |
| request["frequency_penalty"] = float(predict_params.frequency_penalty) | |
| # Генерируем сообщения | |
| request["messages"] = self.create_messages(prompt) | |
| return request | |
| def trim_sources(self, sources: str, user_request: str, system_prompt: str = None) -> dict: | |
| """ | |
| Обрезает текст источников, чтобы уложиться в допустимое количество токенов. | |
| Args: | |
| sources (str): Текст источников. | |
| user_request (str): Запрос пользователя с примененным шаблоном без текста источников. | |
| system_prompt (str): Системный промпт, если нужен. | |
| Returns: | |
| dict: Словарь с результатом, количеством токенов до и после обрезки. | |
| """ | |
| # Токенизация текста источников | |
| sources_tokens_data = self.tokenize(sources) | |
| if sources_tokens_data is None: | |
| raise ValueError("Failed to tokenize sources.") | |
| max_token_count = sources_tokens_data.get("maxLength", 0) | |
| # Токены системного промпта | |
| system_prompt_token_count = 0 | |
| if system_prompt is not None: | |
| system_prompt_tokens = self.tokenize(system_prompt) | |
| system_prompt_token_count = len(system_prompt_tokens["tokens"]) if system_prompt_tokens else 0 | |
| # Оригинальное количество токенов | |
| original_token_count = len(sources_tokens_data["tokens"]) | |
| # Токенизация пользовательского промпта | |
| aux_prompt = self.apply_llm_template_to_prompt(user_request) | |
| aux_tokens_data = self.tokenize(aux_prompt) | |
| aux_token_count = len(aux_tokens_data["tokens"]) if aux_tokens_data else 0 | |
| # Максимально допустимое количество токенов для источников | |
| max_length = ( | |
| max_token_count | |
| - (self.params.predict_params.n_predict or 0) | |
| - aux_token_count | |
| - system_prompt_token_count | |
| ) | |
| max_length = max(max_length, 0) | |
| # Обрезка токенов источников | |
| if "tokens" in sources_tokens_data: | |
| sources_tokens_data["tokens"] = sources_tokens_data["tokens"][:max_length] | |
| detokenized_prompt = self.detokenize(sources_tokens_data["tokens"]) | |
| if detokenized_prompt is not None: | |
| sources = detokenized_prompt | |
| else: | |
| sources = sources[:max_length] | |
| else: | |
| sources = sources[:max_length] | |
| # Возврат результата | |
| return { | |
| "result": sources, | |
| "originalTokenCount": original_token_count, | |
| "slicedTokenCount": len(sources_tokens_data["tokens"]), | |
| } | |
| def predict(self, prompt: str) -> str: | |
| """ | |
| Выполняет SSE-запрос к API и возвращает собранный результат как текст. | |
| Args: | |
| prompt (str): Входной текст для предсказания. | |
| Returns: | |
| str: Сгенерированный текст. | |
| Raises: | |
| Exception: Если запрос завершился ошибкой. | |
| """ | |
| # Создание запроса | |
| request = self.create_request(prompt) | |
| print(f"Predict request. Url: {self.params.url}") | |
| response = requests.post( | |
| f"{self.params.url}/v1/chat/completions", | |
| headers={"Content-Type": "application/json"}, | |
| json=request, | |
| stream=True # Для обработки SSE | |
| ) | |
| if not response.ok: | |
| raise Exception(f"Failed to generate text: {response.text}") | |
| # Обработка SSE-ответа | |
| generated_text = "" | |
| for line in response.iter_lines(decode_unicode=True): | |
| if line.startswith("data: "): | |
| try: | |
| data = json.loads(line[len("data: "):].strip()) | |
| # Проверка завершения генерации | |
| if data == "[DONE]": | |
| break | |
| # Получение текста из ответа | |
| if "choices" in data and data["choices"]: | |
| token_value = data["choices"][0].get("delta", {}).get("content", "") | |
| generated_text += token_value.replace("</s>", "") | |
| except json.JSONDecodeError: | |
| continue # Игнорирование строк, которые не удалось декодировать | |
| return generated_text | |