import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import faiss
import os

# 📌 1. Load the LLaMA 3 model and tokenizer
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map="auto")
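# Note: meta-llama/Meta-Llama-3-8B-Instruct is a gated model, so this script
# assumes the Hugging Face license has been accepted and a valid token is
# available; in float16 the weights need roughly 16 GB of GPU memory.
# A lower-memory alternative (a sketch, assuming bitsandbytes is installed,
# not part of the original script) would load the model in 4-bit instead:
#
#   from transformers import BitsAndBytesConfig
#   bnb_config = BitsAndBytesConfig(load_in_4bit=True,
#                                   bnb_4bit_compute_dtype=torch.float16)
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_NAME, quantization_config=bnb_config, device_map="auto"
#   )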

# 📌 2. Load the Sentence Transformer used for embeddings
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# 📌 3. Load the local knowledge base from the "files" folder
def load_documents():
    knowledge_base = []
    for file_name in os.listdir("files"):
        file_path = os.path.join("files", file_name)
        # Skip subdirectories and anything that is not a regular file
        if not os.path.isfile(file_path):
            continue
        with open(file_path, "r", encoding="utf-8") as file:
            text = file.read()
            knowledge_base.append(text)
    return knowledge_base
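# Each file is treated as one document. For longer files a common refinement
# (shown only as a sketch, not part of the original script) is to split the
# text into fixed-size chunks before embedding, so retrieval returns smaller,
# more focused passages:
#
#   def chunk_text(text, size=500):
#       return [text[i:i + size] for i in range(0, len(text), size)]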

documents = load_documents()
document_embeddings = embedder.encode(documents, convert_to_tensor=True)

# 📌 4. Build the FAISS index over the document embeddings
index = faiss.IndexFlatL2(document_embeddings.shape[1])
index.add(document_embeddings.cpu().numpy())
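# IndexFlatL2 performs exact (brute-force) L2 search over all document
# embeddings, which is fine for a small knowledge base. If cosine similarity
# is preferred, one option (a sketch, not part of the original script) is to
# L2-normalize the embeddings and use an inner-product index instead:
#
#   emb = document_embeddings.cpu().numpy()
#   faiss.normalize_L2(emb)
#   index = faiss.IndexFlatIP(emb.shape[1])
#   index.add(emb)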

# 📌 5. Retrieve the most relevant documents for a query
def retrieve_relevant_info(query, top_k=2):
    query_embedding = embedder.encode([query], convert_to_tensor=True)
    query_embedding = query_embedding.cpu().numpy()
    distances, indices = index.search(query_embedding, top_k)
    retrieved_docs = [documents[idx] for idx in indices[0]]
    return " ".join(retrieved_docs)

# 📌 6. Generate an answer from the retrieved context
def generate_response(query):
    relevant_info = retrieve_relevant_info(query)
    input_text = f"Context: {relevant_info}\nQuestion: {query}\nAnswer:"
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    # max_new_tokens bounds only the generated answer, so a long retrieved
    # context cannot eat into the length budget the way max_length would
    outputs = model.generate(**inputs, max_new_tokens=200)
    # Decode only the newly generated tokens, not the echoed prompt
    answer = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return answer.strip()
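# Quick local check (hypothetical example query; assumes the "files" folder
# contains at least one document):
#
#   print(generate_response("What topics do the documents cover?"))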

# 📌 7. Gradio interface
interface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=2, placeholder="Enter your question..."),
    outputs=gr.Textbox(),
    title="RAG with LLaMA 3",
    description="This chatbot uses RAG (Retrieval-Augmented Generation) with LLaMA 3 and your documents."
)

interface.launch()