import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import faiss
import os
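
# Meta-Llama-3-8B-Instruct is a gated repository, so the Space must be able to
# authenticate with the Hub. A minimal sketch, assuming the token is exposed as
# an HF_TOKEN environment variable (the variable name is an assumption, not
# part of the original app):
from huggingface_hub import login

if os.environ.get("HF_TOKEN"):
    login(token=os.environ["HF_TOKEN"])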
# 📌 1. Load LLaMA 3
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float16, device_map="auto"
)
# 📌 2. Load a Sentence Transformer for embeddings
embedder = SentenceTransformer("all-MiniLM-L6-v2")
# 📌 3. Load the knowledge base from the files/ directory
def load_documents():
    knowledge_base = []
    for file_name in os.listdir("files"):
        file_path = os.path.join("files", file_name)
        with open(file_path, "r", encoding="utf-8") as file:
            text = file.read()
            knowledge_base.append(text)
    return knowledge_base
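
# Whole files can blow past the model's context window once concatenated into a
# prompt. A minimal chunking sketch (optional, not wired into the app; the
# chunk_size value is an arbitrary assumption):
def split_into_chunks(text, chunk_size=1000):
    """Split a document into fixed-size character chunks for finer retrieval."""
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]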
documents = load_documents()
document_embeddings = embedder.encode(documents, convert_to_tensor=True)
# 📌 4. Build the FAISS index
index = faiss.IndexFlatL2(document_embeddings.shape[1])
index.add(document_embeddings.cpu().numpy())
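
# IndexFlatL2 ranks by Euclidean distance. If cosine similarity is preferred,
# one common alternative (a sketch, not what this app does) is to L2-normalize
# the vectors and use an inner-product index:
#   vectors = document_embeddings.cpu().numpy()
#   faiss.normalize_L2(vectors)              # in-place normalization
#   index = faiss.IndexFlatIP(vectors.shape[1])
#   index.add(vectors)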
# 📌 5. Retrieve the most relevant documents for a query
def retrieve_relevant_info(query, top_k=2):
    query_embedding = embedder.encode([query], convert_to_tensor=True)
    query_embedding = query_embedding.cpu().numpy()
    distances, indices = index.search(query_embedding, top_k)
    retrieved_docs = [documents[idx] for idx in indices[0]]
    return " ".join(retrieved_docs)
# 📌 6. Generate an answer grounded in the retrieved context
def generate_response(query):
    relevant_info = retrieve_relevant_info(query)
    input_text = f"Context: {relevant_info}\nQuestion: {query}\nAnswer:"
    # Send inputs to the model's own device (device_map="auto" may not be "cuda")
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    # max_new_tokens bounds the answer length; max_length would count the prompt too
    outputs = model.generate(**inputs, max_new_tokens=200)
    # Decode only the newly generated tokens so the prompt is not echoed back
    answer_ids = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(answer_ids, skip_special_tokens=True)
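
# Llama 3 Instruct is tuned on a chat format, so the plain-text prompt above can
# underperform. A sketch using the tokenizer's chat template instead (an
# alternative, not the original app's prompting):
def generate_response_chat(query):
    relevant_info = retrieve_relevant_info(query)
    messages = [
        {"role": "user", "content": f"Context: {relevant_info}\nQuestion: {query}"}
    ]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(input_ids, max_new_tokens=200)
    # Decode only the newly generated tokens, not the echoed prompt.
    return tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)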
# 📌 7. Gradio interface
interface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=2, placeholder="Enter your question..."),
    outputs=gr.Textbox(),
    title="RAG with LLaMA 3",
    description="This chatbot uses RAG (Retrieval-Augmented Generation) with LLaMA 3 and your documents.",
)
interface.launch()
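
# On Spaces, launch() needs no arguments; when running locally, launch(share=True)
# can expose a temporary public URL (optional, not part of the original app).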