# gradio / app.py
# mariiapaik's picture
# Create app.py
# 0e62273 verified
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
import faiss
import os
# 📌 1. Load LLaMA 3: tokenizer and causal-LM weights for answer generation.
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",  # let transformers place the weights on the available device(s)
    torch_dtype=torch.float16,  # half precision to reduce memory use
)

# 📌 2. Load the Sentence Transformer used to embed documents and queries.
embedder = SentenceTransformer("all-MiniLM-L6-v2")
# 📌 3. Load our own knowledge base.
def load_documents(directory="files"):
    """Read every regular file in *directory* as UTF-8 text.

    Files are visited in sorted name order so that the resulting list (and
    hence each document's position) is reproducible across runs.
    Sub-directories and other non-regular entries are skipped instead of
    making ``open`` crash.

    Args:
        directory: Folder containing the knowledge-base files. Defaults to
            ``"files"``, resolved relative to the current working directory.

    Returns:
        A list with the full text of each file, one string per file.
    """
    knowledge_base = []
    # os.listdir order is arbitrary; sort it for deterministic document ids.
    for file_name in sorted(os.listdir(directory)):
        file_path = os.path.join(directory, file_name)
        # os.listdir also returns sub-directories; open() on them would raise.
        if not os.path.isfile(file_path):
            continue
        with open(file_path, "r", encoding="utf-8") as file:
            knowledge_base.append(file.read())
    return knowledge_base
documents = load_documents()
if not documents:
    # Fail fast with a readable message: with no documents there is no
    # (n_docs, dim) embedding matrix, so the index dimension is undefined.
    raise RuntimeError("No documents found in 'files': the knowledge base is empty.")
document_embeddings = embedder.encode(documents, convert_to_tensor=True)
# 📌 4. Create the FAISS index.
# Exact (brute-force) L2 index; its dimension is the embedding width.
index = faiss.IndexFlatL2(document_embeddings.shape[1])
# FAISS consumes NumPy arrays, so move the tensor to host memory first.
# Vectors are added in the same order as `documents`, so a search result id
# is a direct index into that list.
index.add(document_embeddings.cpu().numpy())
# 📌 5. Function for retrieving relevant information.
def retrieve_relevant_info(query, top_k=2):
    """Return the text of the indexed documents closest to *query*.

    Args:
        query: User question to embed and search for.
        top_k: Maximum number of documents to return. If the index holds
            fewer documents, only the available ones are returned.

    Returns:
        The retrieved documents, nearest first, joined with a single space.
        An empty string if nothing is retrieved (e.g. ``top_k <= 0``).
    """
    if top_k <= 0:
        # FAISS rejects k <= 0; nothing was requested, so return no context.
        return ""
    query_embedding = embedder.encode([query], convert_to_tensor=True)
    query_embedding = query_embedding.cpu().numpy()
    _distances, indices = index.search(query_embedding, top_k)
    # FAISS pads missing results with label -1 when top_k exceeds the number
    # of indexed vectors; documents[-1] would silently repeat the last
    # document, so drop those placeholders.
    retrieved_docs = [documents[idx] for idx in indices[0] if idx >= 0]
    return " ".join(retrieved_docs)
# 📌 6. Answer generation function.
def generate_response(query, max_new_tokens=200):
    """Answer *query* with LLaMA 3, using retrieved documents as context.

    Args:
        query: User question coming from the Gradio textbox.
        max_new_tokens: Upper bound on the number of tokens generated for the
            answer. The prompt does not count against it.

    Returns:
        Only the newly generated answer text, without the echoed prompt.
    """
    relevant_info = retrieve_relevant_info(query)
    input_text = f"Context: {relevant_info}\nQuestion: {query}\nAnswer:"
    # The model was loaded with device_map="auto", so it is not guaranteed to
    # be on "cuda" (e.g. CPU-only hosts); follow the model's own device.
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        # max_length counts the prompt too, so a long retrieved context could
        # leave no room for the answer; bound only the new tokens instead.
        max_new_tokens=max_new_tokens,
        # NOTE(review): presumably the tokenizer defines no pad token; reuse
        # EOS explicitly, as generate() would otherwise do with a warning.
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the continuation.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
# 📌 7. Gradio interface: a two-line question box in, a plain answer box out.
interface = gr.Interface(
    generate_response,
    gr.Textbox(placeholder="Введите ваш вопрос...", lines=2),
    gr.Textbox(),
    description="Этот чатбот использует RAG (Retrieval-Augmented Generation) с LLaMA 3 и вашими документами.",
    title="RAG с LLaMA 3",
)
# Start the web server when the script is executed.
interface.launch()