# NOTE(review): the three lines here ("Spaces: / Sleeping / Sleeping") are
# HuggingFace Spaces page-status banner residue from the scrape, not code.
# Set the Hugging Face cache location BEFORE importing any HF libraries,
# so model downloads land in a writable path (on HF Spaces only /tmp is writable).
import os
os.environ["HF_HOME"] = "/tmp/huggingface"

from fastapi.responses import HTMLResponse
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import faiss
import pickle
app = FastAPI(title="RAG Chatbot API")

# === Enable CORS (optional but recommended for browser-based Razor apps) ===
# Wildcards allow any origin/method/header — fine for development,
# lock down before exposing publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Replace with your domain in production
    allow_methods=["*"],
    allow_headers=["*"],
)
# === Load retrieval artifacts and models once at startup ===
chunks_path = "chunks.pkl"
index_path = "index.faiss"

# NOTE(review): pickle.load executes arbitrary code if the file is untrusted —
# chunks.pkl must come from a trusted build step, never from user uploads.
with open(chunks_path, "rb") as f:
    chunks = pickle.load(f)

# FAISS index built over the same chunk list; row i corresponds to chunks[i].
index = faiss.read_index(index_path)

# Sentence embedder loaded from a locally bundled model directory.
embedder = SentenceTransformer("local_model")

# Seq2seq generator used to synthesize answers from retrieved context.
# NOTE(review): HF_HOME already points at /tmp/huggingface, so cache_dir looks
# redundant — confirm `pipeline` accepts cache_dir directly (vs. model_kwargs).
generator = pipeline("text2text-generation", model="google/flan-t5-base", cache_dir="/tmp/huggingface")
| # === Retrieval and Generation === | |
def retrieve(query, top_k=3):
    """Embed *query* and return the top_k most similar stored text chunks."""
    query_vec = embedder.encode([query], convert_to_numpy=True)
    _, neighbor_ids = index.search(query_vec, top_k)  # distances unused
    return [chunks[idx] for idx in neighbor_ids[0]]
def generate_answer(context, query):
    """Generate an answer to *query* grounded in *context*.

    Builds an instruction prompt, runs the seq2seq generation pipeline with a
    cap of 150 new tokens, and returns the stripped generated text.
    """
    # Fixed typo in the prompt instruction ("meaningfull" -> "meaningful") so
    # the instruction the model sees reads correctly.
    prompt = (
        "Answer the question based on the context and give meaningful ending."
        f"\n\nContext:\n{context}\n\nQuestion: {query}"
    )
    response = generator(prompt, max_new_tokens=150)[0]["generated_text"]
    return response.strip()
# === API Endpoint ===
class QueryRequest(BaseModel):
    """Request body for the JSON question endpoint."""

    query: str  # Must match Razor payload field
# NOTE(review): the route decorator was missing, so this handler was never
# registered with FastAPI (and was later shadowed by the form handler's
# duplicate name). The path "/query" is assumed — confirm against the Razor client.
@app.post("/query")
def ask_question(request: QueryRequest):
    """Answer a question sent as JSON ({"query": "..."}).

    Retrieves the most relevant chunks, joins them into a context string,
    and generates an answer from that context.
    """
    print(f"β Received query: {request.query}")  # Debug log
    retrieved = retrieve(request.query)
    context = "\n".join(retrieved)
    answer = generate_answer(context, request.query)
    return {
        "UserQuery": request.query,
        "RetrievedContext": context,
        "answer": answer,
    }
from fastapi import Form


# NOTE(review): this handler had no route decorator and reused the name
# `ask_question`, clobbering the JSON handler at module level. Renamed and
# registered at "/querystr", matching the home page's form action.
@app.post("/querystr")
def ask_question_form(query: str = Form(...)):
    """Answer a question submitted as HTML form data (the form on "/")."""
    retrieved = retrieve(query)
    context = "\n".join(retrieved)
    answer = generate_answer(context, query)
    return {
        "UserQuery": query,
        "RetrievedContext": context,
        "answer": answer,
    }
# NOTE(review): the route decorator was missing. The HTMLResponse import at the
# top of the file and the raw-HTML return make the intent clear: serve this as
# the landing page at "/".
@app.get("/", response_class=HTMLResponse)
def home():
    """Serve a minimal HTML page with a question form posting to /querystr."""
    return """
<h2>π€ RAG Chatbot</h2>
<form method="post" action="/querystr">
<input name="query" placeholder="Ask a question" style="width:300px;" />
<button type="submit">Ask</button>
</form>
"""