from __future__ import annotations

import os

from gpt4all import GPT4All

from .models import get_models
from ..typing import Messages
# Lazily populated cache of available models, keyed by model name.
MODEL_LIST: dict[str, dict] | None = None
def find_model_dir(model_file: str) -> str:
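    """Return the directory that contains ``model_file``.

    Checks the project-level ``models`` directory first, then the legacy
    directory next to this module, then walks the working directory. Falls
    back to the project-level directory so it can serve as a download target.
    """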
    local_dir = os.path.dirname(os.path.abspath(__file__))
    project_dir = os.path.dirname(os.path.dirname(local_dir))

    # Preferred location: <project>/models
    new_model_dir = os.path.join(project_dir, "models")
    new_model_file = os.path.join(new_model_dir, model_file)
    if os.path.isfile(new_model_file):
        return new_model_dir

    # Legacy location: a "models" directory next to this module.
    old_model_dir = os.path.join(local_dir, "models")
    old_model_file = os.path.join(old_model_dir, model_file)
    if os.path.isfile(old_model_file):
        return old_model_dir

    # Last resort: search the current working directory tree.
    working_dir = "./"
    for root, _dirs, files in os.walk(working_dir):
        if model_file in files:
            return root

    return new_model_dir
class LocalProvider:
    @staticmethod
    def create_completion(model: str, messages: Messages, stream: bool = False, **kwargs):
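        """Generate a completion with a locally installed GPT4All model.

        Yields tokens one at a time when ``stream`` is true, otherwise yields
        the full response as a single string.
        """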
        global MODEL_LIST
        if MODEL_LIST is None:
            MODEL_LIST = get_models()
        if model not in MODEL_LIST:
            raise ValueError(f'Model "{model}" not found / not yet implemented')

        # Resolve the model name to its file and locate (or fetch) it on disk.
        model = MODEL_LIST[model]
        model_file = model["path"]
        model_dir = find_model_dir(model_file)
        if not os.path.isfile(os.path.join(model_dir, model_file)):
            print(f'Model file "models/{model_file}" not found.')
            download = input(f"Do you want to download {model_file}? [y/n]: ")
            if download in ("y", "Y"):
                GPT4All.download_model(model_file, model_dir)
            else:
                raise ValueError(f'Model "{model_file}" not found.')
        # The file is guaranteed to exist at this point, so disable downloads.
        model = GPT4All(model_name=model_file,
                        # n_threads=8,
                        verbose=False,
                        allow_download=False,
                        model_path=model_dir)
        system_message = "\n".join(
            message["content"] for message in messages if message["role"] == "system"
        )
        if not system_message:
            # No system message supplied by the caller; fall back to a default.
            system_message = "A chat between a curious user and an artificial intelligence assistant."
        # Flatten the remaining messages into one prompt and cue the model to
        # answer as ASSISTANT.
        prompt_template = "USER: {0}\nASSISTANT: "
        conversation = "\n".join(
            f"{message['role'].upper()}: {message['content']}"
            for message in messages
            if message["role"] != "system"
        ) + "\nASSISTANT: "
        def should_not_stop(token_id: int, token: str) -> bool:
            # Returning False stops generation: cut the model off as soon as
            # it starts writing the next "USER" turn itself.
            return "USER" not in token
        with model.chat_session(system_message, prompt_template):
            if stream:
                for token in model.generate(conversation, streaming=True, callback=should_not_stop):
                    yield token
            else:
                yield model.generate(conversation, callback=should_not_stop)
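

if __name__ == "__main__":
    # Minimal usage sketch, not part of the provider API. "mistral-7b" is a
    # placeholder name; valid names come from get_models(). Because of the
    # relative imports above, run this as a module (python -m ...) rather
    # than as a script.
    example_messages: Messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
    for token in LocalProvider.create_completion("mistral-7b", example_messages, stream=True):
        print(token, end="", flush=True)
    print()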