shawno committed
Commit 66f4f20 · verified · 1 Parent(s): 205abbc

Update app.py

Files changed (1)
  1. app.py +52 -53
app.py CHANGED
@@ -5,9 +5,11 @@ import chromadb
 from chromadb.utils import embedding_functions
 from fastapi import FastAPI, Query
 import gradio as gr
+from functools import lru_cache
 
-# === Globals ===
-TOKEN_LIMIT = 256  # Default, overridden by slider
+# === Config ===
+TOKEN_LIMIT = 256
+USE_RAG = True  # Toggle RAG mode
 
 # === Load LLM ===
 model = Llama.from_pretrained(
@@ -17,39 +19,54 @@ model = Llama.from_pretrained(
     n_threads=os.cpu_count(),
 )
 
-# === RAG Setup ===
-"""embedder = SentenceTransformer("all-MiniLM-L6-v2")
-client = chromadb.PersistentClient(path="chroma_db")
-col = client.get_or_create_collection(
-    "docs",
-    embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
-        model_name="all-MiniLM-L6-v2"
+# === Optional: RAG Setup ===
+if USE_RAG:
+    embedder = SentenceTransformer("all-MiniLM-L6-v2")
+    client = chromadb.PersistentClient(path="chroma_db")
+    col = client.get_or_create_collection(
+        "docs",
+        embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
+            model_name="all-MiniLM-L6-v2"
+        )
     )
-)
-seed_texts = [
-    "MiniCPM‑V‑2_6‑gguf runs well on CPU via llama.cpp.",
-    "This model supports RAG with Chromadb and FastAPI + Gradio UI."
-]
-for t in seed_texts:
-    col.add(documents=[t], ids=[str(hash(t))])"""
-
-# === Query Function ===
+
+    seed_texts = [
+        "MiniCPM‑V‑2_6‑gguf runs well on CPU via llama.cpp.",
+        "This model supports RAG with Chromadb and FastAPI + Gradio UI."
+    ]
+    for t in seed_texts:
+        try:
+            col.add(documents=[t], ids=[str(hash(t))])
+        except:
+            pass  # Avoid duplicates on restart
+
+    @lru_cache(maxsize=128)
+    def embed_query(q: str):
+        return embedder.encode(q)
+
+# === RAG or Vanilla Query ===
 def rag_query(q: str, max_tokens: int) -> str:
-    """results = col.query(
-        query_embeddings=[embedder.encode(q)],
-        n_results=3
-    )"""
-    #context = "\n".join(results["documents"][0])
-    context=""
-    prompt = f"Context:\n{context}\n\nUser: {q}\nAssistant:"
-    out = model.create_completion(prompt=prompt, max_tokens=max_tokens, temperature=0.7)
-    return out["choices"][0]["text"]
-
-# === FastAPI App ===
+    try:
+        context = ""
+        if USE_RAG:
+            results = col.query(
+                query_embeddings=[embed_query(q)],
+                n_results=3
+            )
+            context = "\n".join(results["documents"][0])
+
+        prompt = f"Context:\n{context}\n\nUser: {q}\nAssistant:"
+        out = model.create_completion(prompt=prompt, max_tokens=max_tokens, temperature=0.7)
+        return out["choices"][0]["text"]
+    except Exception as e:
+        return f"[Error] {e}"
+
+# === FastAPI ===
 app = FastAPI()
 
 @app.get("/ask")
 def ask(q: str = Query(...), tokens: int = Query(TOKEN_LIMIT)):
+    tokens = min(max(32, tokens), 1024)
     return {"answer": rag_query(q, tokens)}
 
 @app.post("/ask")
@@ -59,36 +76,18 @@ def ask_post(body: dict):
 # === Gradio UI ===
 def chat_fn(message, history, max_tokens):
     history = history or []
-    """history.append(gr.ChatMessage(role="user",
-                       content=message))"""
-    new_history = history + [gr.ChatMessage(role="user", content=message)]
-
-    yield new_history, new_history, ""  # Show user's message immediately
-    #reply = rag_query(message, max_tokens)
-
-    new_history.append(gr.ChatMessage(role="assistant", content="reply"))
-
+    reply = rag_query(message, max_tokens)
+    new_history = history + [(message, reply)]
     yield new_history, new_history, ""
-    """history.append(gr.ChatMessage(role="assistant",
-                       content=reply))"""
-    #history.append((f"🧑 You", message))
-    #history.append((f"🤖 Bot", reply))
-    #return history, history, ""
 
 with gr.Blocks() as demo:
     gr.Markdown("### 🧠 MiniCPM‑V‑2_6‑gguf RAG Chat")
 
-    chatbot = gr.Chatbot(type="messages", label="Bella Lite", autoscroll=True, resizable=True, show_copy_button=True)
-    """with gr.Row():
-        txt = gr.Textbox(placeholder="Ask me...", show_label=False, scale=8)
-        send_btn = gr.Button("Send", scale=1)"""
-
-    txt = gr.Textbox(placeholder="Ask me...", show_label=False, submit_btn="Ask")
-
-    token_slider = gr.Slider(64, 1024, value=256, step=16, label="Max tokens")
+    chatbot = gr.Chatbot(label="Bella Lite", show_copy_button=True)
+    txt = gr.Textbox(placeholder="Ask me anything...", show_label=False)
+    token_slider = gr.Slider(64, 1024, value=TOKEN_LIMIT, step=16, label="Max Tokens")
 
     txt.submit(chat_fn, [txt, chatbot, token_slider], [chatbot, chatbot, txt])
-    #send_btn.click(chat_fn, [txt, chatbot, token_slider], [chatbot, chatbot, txt])
 
 @app.on_event("startup")
 def startup():
@@ -96,4 +95,4 @@ def startup():
 
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860)
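
For reference, a minimal smoke test for the updated GET /ask route. This sketch is not part of the commit: it assumes the Space is running locally on the host and port passed to uvicorn.run above, and that the requests package is installed; the question text is illustrative only.

# Hypothetical client-side check of the GET /ask route added in app.py.
# Assumes the server from this commit is reachable at http://localhost:7860.
import requests

resp = requests.get(
    "http://localhost:7860/ask",
    params={"q": "What runtime does the model use?", "tokens": 128},  # tokens is clamped to 32-1024 server side
)
resp.raise_for_status()
print(resp.json()["answer"])

The diff also keeps a POST /ask handler (ask_post), so the GET form above is simply the quickest way to exercise rag_query end to end.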