# rag-chatbot / app.py
# giveaccesstoall's picture
# Update app.py
# 21ae421 verified
import os
os.environ["HF_HOME"] = "/tmp/huggingface"
from fastapi.responses import HTMLResponse
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import faiss
import pickle
# ASGI application object; the route handlers further down attach to it
# through the @app.post / @app.get decorators.
app = FastAPI(title="RAG Chatbot API")
# === Enable CORS (optional but recommended for browser-based Razor apps) ===
# Every origin, HTTP method and request header is currently accepted, so any
# web page can call this API from a browser.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Replace with your domain in production
    allow_methods=["*"],
    allow_headers=["*"],
)
# === Load once at startup ===
# Everything below runs at import time so that each request reuses the same
# in-memory chunk list, FAISS index and models.
chunks_path = "chunks.pkl"
index_path = "index.faiss"

# SECURITY NOTE: pickle.load executes arbitrary code embedded in the file.
# Only ship a chunks.pkl that was produced by a trusted build step.
with open(chunks_path, "rb") as f:
    chunks = pickle.load(f)

index = faiss.read_index(index_path)

# Sentence embedder loaded from the "local_model" directory
# (previously the hub model "all-MiniLM-L6-v2").
embedder = SentenceTransformer("local_model")

# Text generator. Loader options such as cache_dir must travel through
# model_kwargs, which pipeline() forwards to from_pretrained(); a loose
# cache_dir keyword is instead handed to the pipeline object itself and is
# treated as a generation argument.
# (Previously tried: hub "google/flan-t5-base" without a cache dir, and a
# local "flan-t5-base-local" directory.)
generator = pipeline(
    "text2text-generation",
    model="google/flan-t5-base",
    model_kwargs={"cache_dir": "/tmp/huggingface"},
)
# === Retrieval and Generation ===
def retrieve(query, top_k=3):
    """Return the stored chunks closest to *query* in the FAISS index.

    Args:
        query: Text to embed and search for.
        top_k: Maximum number of neighbours to request from the index.

    Returns:
        A list with at most ``top_k`` entries of the module-level ``chunks``
        list, in the order the index reported them. The list is shorter
        than ``top_k`` when the index could not supply that many results.
    """
    query_embedding = embedder.encode([query], convert_to_numpy=True)
    _distances, indices = index.search(query_embedding, top_k)
    # FAISS pads the label row with -1 when fewer than top_k neighbours
    # exist. Indexing chunks[-1] would silently return the *last* chunk as
    # if it were a hit, so drop those placeholder labels.
    return [chunks[i] for i in indices[0] if i >= 0]
def generate_answer(context, query):
    """Ask the module-level text generator to answer *query* from *context*.

    Args:
        context: Retrieved text placed under the "Context:" heading.
        query: The user's question placed after "Question:".

    Returns:
        The ``generated_text`` of the first generator output, with leading
        and trailing whitespace removed.
    """
    instruction = "Answer the question based on the context and give meaningfull ending."
    prompt = f"{instruction}\n\nContext:\n{context}\n\nQuestion: {query}"
    outputs = generator(prompt, max_new_tokens=150)
    first_output = outputs[0]
    return first_output["generated_text"].strip()
# === API Endpoint ===
class QueryRequest(BaseModel):
    """JSON request body accepted by the ``/query`` endpoint."""

    # Field name must match the payload sent by the Razor client.
    query: str
@app.post("/query")
def ask_question(request: QueryRequest):
    """Answer a question sent as a JSON body.

    Retrieves the closest chunks for the query, joins them with newlines
    to form the context, and asks the generator for an answer.

    Args:
        request: Parsed request body carrying the ``query`` string.

    Returns:
        A dict with the original query (``UserQuery``), the joined context
        (``RetrievedContext``) and the generated ``answer``.
    """
    query = request.query
    # The check-mark emoji had been corrupted into mojibake by a wrong
    # decode; restored to the intended character.
    print(f"✅ Received query: {query}")  # Debug log
    retrieved = retrieve(query)
    context = "\n".join(retrieved)
    answer = generate_answer(context, query)
    return {
        "UserQuery": query,
        "RetrievedContext": context,
        "answer": answer,
    }
from fastapi import Form
@app.post("/querystr")
def ask_question(query: str = Form(...)):
    """Answer a question submitted as an HTML form field.

    Same retrieve-then-generate flow as the JSON ``/query`` endpoint, but
    the query arrives as the required form field ``query``.

    NOTE(review): this function reuses the name of the ``/query`` handler.
    Both routes stay registered because the decorators run at definition
    time, but the module-level name ``ask_question`` ends up bound to this
    form variant.

    Returns:
        A dict with the original query (``UserQuery``), the joined context
        (``RetrievedContext``) and the generated ``answer``.
    """
    context = "\n".join(retrieve(query))
    return {
        "UserQuery": query,
        "RetrievedContext": context,
        "answer": generate_answer(context, query),
    }
@app.get("/", response_class=HTMLResponse)
def home():
    """Serve a minimal HTML page with a form that posts to ``/querystr``.

    Returns:
        An HTML fragment containing a heading, a text input named
        ``query`` and a submit button.
    """
    # The robot emoji in the heading had been corrupted into mojibake by a
    # wrong decode; restored to the intended character.
    return """
    <h2>🤖 RAG Chatbot</h2>
    <form method="post" action="/querystr">
    <input name="query" placeholder="Ask a question" style="width:300px;" />
    <button type="submit">Ask</button>
    </form>
    """