| from __future__ import annotations | |
| import random | |
| import string | |
| from pathlib import Path | |
| from ..typing import Messages, Cookies, AsyncIterator, Iterator | |
| from ..tools.files import get_bucket_dir, read_bucket | |
| from .. import debug | |
| def to_string(value) -> str: | |
| if isinstance(value, str): | |
| return value | |
| elif isinstance(value, dict): | |
| if "text" in value: | |
| return value["text"] | |
| elif "name" in value: | |
| return "" | |
| elif "bucket_id" in value: | |
| bucket_dir = Path(get_bucket_dir(value.get("bucket_id"))) | |
| return "".join(read_bucket(bucket_dir)) | |
| return "" | |
| elif isinstance(value, list): | |
| return "".join([to_string(v) for v in value if v.get("type", "text") == "text"]) | |
| elif value is None: | |
| return "" | |
| return str(value) | |
| def render_messages(messages: Messages) -> Iterator: | |
| for idx, message in enumerate(messages): | |
| if isinstance(message, dict) and isinstance(message.get("content"), list): | |
| yield { | |
| **message, | |
| "content": to_string(message["content"]), | |
| } | |
| else: | |
| yield message | |
| def format_prompt(messages: Messages, add_special_tokens: bool = False, do_continue: bool = False, include_system: bool = True) -> str: | |
| """ | |
| Format a series of messages into a single string, optionally adding special tokens. | |
| Args: | |
| messages (Messages): A list of message dictionaries, each containing 'role' and 'content'. | |
| add_special_tokens (bool): Whether to add special formatting tokens. | |
| Returns: | |
| str: A formatted string containing all messages. | |
| """ | |
| if not add_special_tokens and len(messages) <= 1: | |
| return to_string(messages[0]["content"]) | |
| messages = [ | |
| (message["role"], to_string(message["content"])) | |
| for message in messages | |
| if include_system or message.get("role") not in ("developer", "system") | |
| ] | |
| formatted = "\n".join([ | |
| f'{role.capitalize()}: {content}' | |
| for role, content in messages | |
| if content.strip() | |
| ]) | |
| if do_continue: | |
| return formatted | |
| return f"{formatted}\nAssistant:" | |
| def get_system_prompt(messages: Messages) -> str: | |
| return "\n".join([m["content"] for m in messages if m["role"] in ("developer", "system")]) | |
| def get_last_user_message(messages: Messages, include_buckets: bool = True) -> str: | |
| user_messages = [] | |
| for message in messages[::-1]: | |
| if message.get("role") == "user" or not user_messages: | |
| if message.get("role") != "user": | |
| continue | |
| content = message.get("content") | |
| if include_buckets: | |
| content = to_string(content).strip() | |
| if isinstance(content, str): | |
| user_messages.append(content) | |
| else: | |
| for content_item in content: | |
| if content_item.get("type") == "text": | |
| content = content_item.get("text").strip() | |
| if content: | |
| user_messages.append(content) | |
| else: | |
| return "\n".join(user_messages[::-1]) | |
| return "\n".join(user_messages[::-1]) | |
| def get_last_message(messages: Messages, prompt: str = None) -> str: | |
| if prompt is None: | |
| for message in messages[::-1]: | |
| content = to_string(message.get("content")).strip() | |
| if content: | |
| prompt = content | |
| return prompt | |
| def format_media_prompt(messages, prompt: str = None) -> str: | |
| if prompt is None: | |
| return get_last_user_message(messages) | |
| return prompt | |
| def format_prompt_max_length(messages: Messages, max_lenght: int) -> str: | |
| prompt = format_prompt(messages) | |
| start = len(prompt) | |
| if start > max_lenght: | |
| if len(messages) > 6: | |
| prompt = format_prompt(messages[:3] + messages[-3:]) | |
| if len(prompt) > max_lenght: | |
| if len(messages) > 2: | |
| prompt = format_prompt([m for m in messages if m["role"] == "system"] + messages[-1:]) | |
| if len(prompt) > max_lenght: | |
| prompt = messages[-1]["content"] | |
| debug.log(f"Messages trimmed from: {start} to: {len(prompt)}") | |
| return prompt | |
| def get_random_string(length: int = 10) -> str: | |
| """ | |
| Generate a random string of specified length, containing lowercase letters and digits. | |
| Args: | |
| length (int, optional): Length of the random string to generate. Defaults to 10. | |
| Returns: | |
| str: A random string of the specified length. | |
| """ | |
| return ''.join( | |
| random.choice(string.ascii_lowercase + string.digits) | |
| for _ in range(length) | |
| ) | |
| def get_random_hex(length: int = 32) -> str: | |
| """ | |
| Generate a random hexadecimal string with n length. | |
| Returns: | |
| str: A random hexadecimal string of n characters. | |
| """ | |
| return ''.join( | |
| random.choice("abcdef" + string.digits) | |
| for _ in range(length) | |
| ) | |
| def filter_none(**kwargs) -> dict: | |
| return { | |
| key: value | |
| for key, value in kwargs.items() | |
| if value is not None | |
| } | |
| async def async_concat_chunks(chunks: AsyncIterator) -> str: | |
| return concat_chunks([chunk async for chunk in chunks]) | |
| def concat_chunks(chunks: Iterator) -> str: | |
| return "".join([ | |
| str(chunk) for chunk in chunks | |
| if chunk and not isinstance(chunk, Exception) | |
| ]) | |
| def format_cookies(cookies: Cookies) -> str: | |
| return "; ".join([f"{k}={v}" for k, v in cookies.items()]) |