"""RAG chat demo: MiniCPM-V-2_6 (GGUF via llama.cpp) + Chroma retrieval,
served through a FastAPI API and a Gradio chat UI."""

import os

from llama_cpp import Llama
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.utils import embedding_functions
from fastapi import FastAPI, Query
import gradio as gr

# Load the GGUF model from the Hugging Face Hub. NOTE: filename is a glob;
# if the repo ships several .gguf files, a more specific pattern may be
# needed so that exactly one file matches.
model = Llama.from_pretrained(
    repo_id="openbmb/MiniCPM-V-2_6-gguf",
    filename="*.gguf",
    n_ctx=4096,
)

# Sentence-transformers model used to embed queries at retrieval time.
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# Persistent Chroma collection; documents added without explicit embeddings
# are embedded by the collection's SentenceTransformer embedding function.
client = chromadb.PersistentClient(path="chroma_db")
col = client.get_or_create_collection(
    "docs",
    embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="all-MiniLM-L6-v2"
    ),
)

# A couple of seed documents so the collection is not empty on first run.
seed_texts = [
    "MiniCPM-V-2_6-gguf runs well on CPU via llama.cpp.",
    "This model supports RAG with Chromadb and FastAPI + Gradio UI.",
]
for i, t in enumerate(seed_texts):
    # Stable ids so re-running against the persistent store does not create
    # duplicates (the built-in hash() is randomized per process).
    col.add(documents=[t], ids=[f"seed-{i}"])


def rag_query(q: str) -> str:
    """Retrieve the top matching documents for q and answer with the local model."""
    results = col.query(
        query_embeddings=[embedder.encode(q).tolist()],
        n_results=3,
    )
    context = "\n".join(results["documents"][0])
    prompt = f"Context:\n{context}\n\nUser: {q}\nAssistant:"
    # llama-cpp-python exposes text completion via create_completion()
    # (or by calling the model directly); Llama has no create() method.
    out = model.create_completion(prompt=prompt, max_tokens=256, temperature=0.7)
    return out["choices"][0]["text"]


app = FastAPI()


@app.get("/ask")
def ask(q: str = Query(...)):
    return {"answer": rag_query(q)}


@app.post("/ask")
def ask_post(body: dict):
    # Accepts a JSON body like {"q": "..."} and reuses the GET handler.
    return ask(q=body.get("q", ""))
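
# Example requests once the server is running (assuming the default port 7860
# used below; adjust if the PORT environment variable is set differently):
#   curl "http://localhost:7860/ask?q=What+does+this+model+support"
#   curl -X POST http://localhost:7860/ask \
#        -H "Content-Type: application/json" -d '{"q": "What does this model support?"}'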


def chat_fn(message, history):
    """Gradio chat handler: Chatbot history is a list of (user, bot) pairs."""
    reply = rag_query(message)
    history = history or []
    history.append((message, reply))
    # Return the updated history and an empty string to clear the textbox.
    return history, ""


with gr.Blocks() as demo:
    gr.Markdown("### 🧠 MiniCPM-V-2_6-gguf RAG Chat (GET & POST API support)")
    chatbot = gr.Chatbot()
    txt = gr.Textbox(placeholder="Ask me...", show_label=False)
    txt.submit(chat_fn, [txt, chatbot], [chatbot, txt])


# Serve the Gradio UI from the same FastAPI/uvicorn server instead of calling
# demo.launch() in a startup hook, which would try to bind a second server on
# the same port.
app = gr.mount_gradio_app(app, demo, path="/")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))