Spaces:
Sleeping
Sleeping
| import shutil | |
| import os | |
| import gradio as gr | |
| import torch | |
| from uuid import uuid4 | |
| #from huggingface_hub.file_download import http_get | |
| from langchain_community.document_loaders import ( | |
| CSVLoader, | |
| EverNoteLoader, | |
| PDFMinerLoader, | |
| TextLoader, | |
| UnstructuredEmailLoader, | |
| UnstructuredEPubLoader, | |
| UnstructuredHTMLLoader, | |
| UnstructuredMarkdownLoader, | |
| UnstructuredODTLoader, | |
| UnstructuredPowerPointLoader, | |
| UnstructuredWordDocumentLoader, | |
| ) | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.docstore.document import Document | |
| from sentence_transformers import SentenceTransformer | |
| from sentence_transformers.util import cos_sim | |
| import llama_cpp | |
| from llama_cpp import Llama | |
| SYSTEM_PROMPT = "Ты — русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им." | |
| LOADER_MAPPING = { | |
| ".csv": (CSVLoader, {}), | |
| ".doc": (UnstructuredWordDocumentLoader, {}), | |
| ".docx": (UnstructuredWordDocumentLoader, {}), | |
| ".enex": (EverNoteLoader, {}), | |
| ".epub": (UnstructuredEPubLoader, {}), | |
| ".html": (UnstructuredHTMLLoader, {}), | |
| ".md": (UnstructuredMarkdownLoader, {}), | |
| ".odt": (UnstructuredODTLoader, {}), | |
| ".pdf": (PDFMinerLoader, {}), | |
| ".ppt": (UnstructuredPowerPointLoader, {}), | |
| ".pptx": (UnstructuredPowerPointLoader, {}), | |
| ".txt": (TextLoader, {"encoding": "utf8"}), | |
| } | |
| def load_model( | |
| directory: str = "." | |
| ): | |
| model = Llama(model_path = "4.gguf", n_ctx = 3096, n_gpu_layers=-1, n_batch = 512, chat_format="gemma", verbose=False) | |
| print("Model loaded!") | |
| return model | |
| EMBEDDER = SentenceTransformer("deepvk/USER-bge-m3") | |
| MODEL = load_model() | |
| def get_uuid(): | |
| return str(uuid4()) | |
| def load_single_document(file_path: str) -> Document: | |
| ext = "." + file_path.rsplit(".", 1)[-1] | |
| assert ext in LOADER_MAPPING | |
| loader_class, loader_args = LOADER_MAPPING[ext] | |
| loader = loader_class(file_path, **loader_args) | |
| return loader.load()[0] | |
| def get_message_tokens(model, role, content): | |
| content = f"{role}\n{content}\n</s>" | |
| content = content.encode("utf-8") | |
| return model.tokenize(content, special=True) | |
| def get_system_tokens(model): | |
| system_message = {"role": "system", "content": SYSTEM_PROMPT} | |
| return get_message_tokens(model, **system_message) | |
| def process_text(text): | |
| lines = text.split("\n") | |
| lines = [line for line in lines if len(line.strip()) > 2] | |
| text = "\n".join(lines).strip() | |
| if len(text) < 10: | |
| return None | |
| return text | |
| def upload_files(files, file_paths): | |
| file_paths = [f.name for f in files] | |
| return file_paths | |
| def build_index(file_paths, db, chunk_size, chunk_overlap, file_warning): | |
| documents = [load_single_document(path) for path in file_paths] | |
| text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) | |
| documents = text_splitter.split_documents(documents) | |
| print("Documents after split:", len(documents)) | |
| fixed_documents = [] | |
| for doc in documents: | |
| doc.page_content = process_text(doc.page_content) | |
| if not doc.page_content: | |
| continue | |
| fixed_documents.append(doc) | |
| print("Documents after processing:", len(fixed_documents)) | |
| texts = [doc.page_content for doc in fixed_documents] | |
| embeddings = EMBEDDER.encode(texts, convert_to_tensor=True) | |
| db = {"docs": texts, "embeddings": embeddings} | |
| print("Embeddings calculated!") | |
| file_warning = f"Загружено {len(fixed_documents)} фрагментов! Можно задавать вопросы." | |
| return db, file_warning | |
| def retrieve(history, db, retrieved_docs, k_documents): | |
| retrieved_docs = "" | |
| if db: | |
| last_user_message = history[-1][0] | |
| query_embedding = EMBEDDER.encode(last_user_message, convert_to_tensor=True) | |
| scores = cos_sim(query_embedding, db["embeddings"])[0] | |
| top_k_idx = torch.topk(scores, k=k_documents)[1] | |
| top_k_documents = [db["docs"][idx] for idx in top_k_idx] | |
| retrieved_docs = "\n\n".join(top_k_documents) | |
| return retrieved_docs | |
| def user(message, history, system_prompt): | |
| new_history = history + [[message, None]] | |
| return "", new_history | |
| def bot( | |
| history, | |
| system_prompt, | |
| conversation_id, | |
| retrieved_docs, | |
| top_p, | |
| top_k, | |
| temp | |
| ): | |
| model = MODEL | |
| if not history: | |
| return | |
| tokens = get_system_tokens(model)[:] | |
| for user_message, bot_message in history[:-1]: | |
| message_tokens = get_message_tokens(model=model, role="user", content=user_message) | |
| tokens.extend(message_tokens) | |
| if bot_message: | |
| message_tokens = get_message_tokens(model=model, role="bot", content=bot_message) | |
| tokens.extend(message_tokens) | |
| last_user_message = history[-1][0] | |
| if retrieved_docs: | |
| last_user_message = f"Контекст: {retrieved_docs}\n\nИспользуя контекст, ответь на вопрос: {last_user_message}" | |
| message_tokens = get_message_tokens(model=model, role="user", content=last_user_message) | |
| tokens.extend(message_tokens) | |
| role_tokens = model.tokenize("bot\n".encode("utf-8"), special=True) | |
| tokens.extend(role_tokens) | |
| generator = model.generate( | |
| tokens, | |
| top_k=top_k, | |
| top_p=top_p, | |
| temp=temp | |
| ) | |
| partial_text = "" | |
| for i, token in enumerate(generator): | |
| if token == model.token_eos(): | |
| break | |
| partial_text += model.detokenize([token]).decode("utf-8", "ignore") | |
| history[-1][1] = partial_text | |
| yield history | |
| with gr.Blocks( | |
| theme=gr.themes.Soft() | |
| ) as demo: | |
| db = gr.State(None) | |
| conversation_id = gr.State(get_uuid) | |
| gr.Markdown( | |
| f"""<h1><center>Вопросно-ответная система по Вашим документам. Работает на CPU.\n | |
| На демо-стенде реализован простейший алгоритм.\n | |
| При внедрении в IT-контуре компании, качество поиска и ответа выше в разы.\n | |
| Для внедрения быстрой версии на GPU (ответ быстрее в 20-100 раз) в информационном контуре Вашей организации, пишите на\n | |
| e-mail: info@digital-human.ru</center></h1> | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=5): | |
| file_output = gr.File(file_count="multiple", label="Загрузка файлов") | |
| file_paths = gr.State([]) | |
| file_warning = gr.Markdown(f"Фрагменты ещё не загружены!") | |
| with gr.Column(min_width=200, scale=3): | |
| with gr.Tab(label="Параметры разбивки текста"): | |
| chunk_size = gr.Slider( | |
| minimum=50, | |
| maximum=2000, | |
| value=256, | |
| step=50, | |
| interactive=True, | |
| label="Размер фрагментов", | |
| ) | |
| chunk_overlap = gr.Slider( | |
| minimum=0, | |
| maximum=500, | |
| value=32, | |
| step=10, | |
| interactive=True, | |
| label="Пересечение" | |
| ) | |
| with gr.Row(): | |
| k_documents = gr.Slider( | |
| minimum=1, | |
| maximum=10, | |
| value=4, | |
| step=1, | |
| interactive=True, | |
| label="Кол-во фрагментов для контекста" | |
| ) | |
| with gr.Row(): | |
| retrieved_docs = gr.Textbox( | |
| lines=6, | |
| label="Извлеченные фрагменты", | |
| placeholder="Появятся после задания вопросов", | |
| interactive=False | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=5): | |
| system_prompt = gr.Textbox(label="Системный промпт", placeholder="", value=SYSTEM_PROMPT, interactive=False) | |
| #chatbot = gr.Chatbot(label="Диалог").style(height=400) | |
| chatbot = gr.Chatbot(label="Диалог") | |
| with gr.Column(min_width=80, scale=1): | |
| with gr.Tab(label="Параметры генерации"): | |
| top_p = gr.Slider( | |
| minimum=0.0, | |
| maximum=1.0, | |
| value=0.9, | |
| step=0.05, | |
| interactive=True, | |
| label="Top-p", | |
| ) | |
| top_k = gr.Slider( | |
| minimum=10, | |
| maximum=100, | |
| value=30, | |
| step=5, | |
| interactive=True, | |
| label="Top-k", | |
| ) | |
| temp = gr.Slider( | |
| minimum=0.0, | |
| maximum=2.0, | |
| value=0.01, | |
| step=0.1, | |
| interactive=True, | |
| label="Temp" | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| msg = gr.Textbox( | |
| label="Отправить сообщение", | |
| placeholder="Отправить сообщение", | |
| show_label=False, | |
| ) | |
| #.style(container=False) | |
| with gr.Column(): | |
| with gr.Row(): | |
| submit = gr.Button("Отправить") | |
| stop = gr.Button("Остановить") | |
| clear = gr.Button("Очистить") | |
| # Upload files | |
| upload_event = file_output.change( | |
| fn=upload_files, | |
| inputs=[file_output, file_paths], | |
| outputs=[file_paths], | |
| queue=True, | |
| ).success( | |
| fn=build_index, | |
| inputs=[file_paths, db, chunk_size, chunk_overlap, file_warning], | |
| outputs=[db, file_warning], | |
| queue=True | |
| ) | |
| # Pressing Enter | |
| submit_event = msg.submit( | |
| fn=user, | |
| inputs=[msg, chatbot, system_prompt], | |
| outputs=[msg, chatbot], | |
| queue=False, | |
| ).success( | |
| fn=retrieve, | |
| inputs=[chatbot, db, retrieved_docs, k_documents], | |
| outputs=[retrieved_docs], | |
| queue=True, | |
| ).success( | |
| fn=bot, | |
| inputs=[ | |
| chatbot, | |
| system_prompt, | |
| conversation_id, | |
| retrieved_docs, | |
| top_p, | |
| top_k, | |
| temp | |
| ], | |
| outputs=chatbot, | |
| queue=True, | |
| ) | |
| # Pressing the button | |
| submit_click_event = submit.click( | |
| fn=user, | |
| inputs=[msg, chatbot, system_prompt], | |
| outputs=[msg, chatbot], | |
| queue=False, | |
| ).success( | |
| fn=retrieve, | |
| inputs=[chatbot, db, retrieved_docs, k_documents], | |
| outputs=[retrieved_docs], | |
| queue=True, | |
| ).success( | |
| fn=bot, | |
| inputs=[ | |
| chatbot, | |
| system_prompt, | |
| conversation_id, | |
| retrieved_docs, | |
| top_p, | |
| top_k, | |
| temp | |
| ], | |
| outputs=chatbot, | |
| queue=True, | |
| ) | |
| # Stop generation | |
| stop.click( | |
| fn=None, | |
| inputs=None, | |
| outputs=None, | |
| cancels=[submit_event, submit_click_event], | |
| queue=False, | |
| ) | |
| # Clear history | |
| clear.click(lambda: None, None, chatbot, queue=False) | |
| #demo.queue(max_size=128, concurrency_limit=1) | |
| demo.launch(show_error=True) | |
| #demo.launch() | |