CHUNYU0505 committed on
Commit
fb13185
·
verified ·
1 Parent(s): fc6e44c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -107
app.py CHANGED
@@ -1,147 +1,68 @@
1
- # app.py
2
  # -------------------------------
3
- # 1. 套件載入
4
  # -------------------------------
5
- import os, glob, requests
6
- from langchain.docstore.document import Document
7
- from langchain.text_splitter import RecursiveCharacterTextSplitter
8
- from langchain_community.vectorstores import FAISS
9
- from langchain_huggingface import HuggingFaceEmbeddings
10
- from docx import Document as DocxDocument
11
- import gradio as gr
 
12
 
13
- # -------------------------------
14
- # 2. 環境變數與資料路徑
15
- # -------------------------------
16
- TXT_FOLDER = "./out_texts"
17
- DB_PATH = "./faiss_db"
18
- os.makedirs(DB_PATH, exist_ok=True)
19
- os.makedirs(TXT_FOLDER, exist_ok=True)
20
-
21
- HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
22
- if not HF_TOKEN:
23
- raise ValueError(
24
- "請在 Hugging Face Space 的 Settings → Repository secrets 設定 HUGGINGFACEHUB_API_TOKEN"
25
- )
26
-
27
- # -------------------------------
28
- # 3. 建立或載入向量資料庫
29
- # -------------------------------
30
- EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
31
- embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
32
-
33
- if os.path.exists(os.path.join(DB_PATH, "index.faiss")):
34
- print("載入現有向量資料庫...")
35
- db = FAISS.load_local(DB_PATH, embeddings_model, allow_dangerous_deserialization=True)
36
- else:
37
- print("沒有資料庫,開始建立新向量資料庫...")
38
- txt_files = glob.glob(f"{TXT_FOLDER}/*.txt")
39
- if not txt_files:
40
- print("注意:TXT 資料夾中沒有任何文字檔,向量資料庫將為空。")
41
- docs = []
42
- for filepath in txt_files:
43
- with open(filepath, "r", encoding="utf-8") as f:
44
- docs.append(Document(page_content=f.read(), metadata={"source": os.path.basename(filepath)}))
45
- splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
46
- split_docs = splitter.split_documents(docs)
47
- db = FAISS.from_documents(split_docs, embeddings_model)
48
- db.save_local(DB_PATH)
49
-
50
- retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
51
-
52
- # -------------------------------
53
- # 4. 定義 REST API 呼叫函數
54
- # -------------------------------
55
- HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
56
-
57
- def call_hf_inference(model_name, prompt, max_new_tokens=512):
58
- api_url = f"https://api-inference.huggingface.co/models/{model_name}"
59
- payload = {
60
- "inputs": prompt,
61
- "parameters": {"max_new_tokens": max_new_tokens}
62
- }
63
- try:
64
- response = requests.post(api_url, headers=HEADERS, json=payload, timeout=180) # timeout 拉長
65
- response.raise_for_status()
66
- data = response.json()
67
- if isinstance(data, list) and "generated_text" in data[0]:
68
- return data[0]["generated_text"]
69
- elif isinstance(data, dict) and "error" in data:
70
- return f"(生成失敗:{data['error']},請嘗試換一個模型)"
71
- else:
72
- return str(data)
73
- except requests.exceptions.ReadTimeout:
74
- return "(生成失敗:等待超時,請嘗試換小一點的模型或增加 timeout 秒數)"
75
- except Exception as e:
76
- return f"(生成失敗:{e})"
77
-
78
- # -------------------------------
79
- # 5. 查詢 API 剩餘額度
80
- # -------------------------------
81
- def get_hf_rate_limit():
82
- try:
83
- r = requests.get("https://huggingface.co/api/whoami", headers=HEADERS)
84
- r.raise_for_status()
85
- data = r.json()
86
- remaining = data.get("rate_limit", {}).get("remaining", "未知")
87
- return f"本小時剩餘 API 次數:約 {remaining}"
88
- except Exception:
89
- return "無法取得 API 速率資訊"
90
-
91
- # -------------------------------
92
- # 6. 生成文章(即時進度)
93
- # -------------------------------
94
  def generate_article_progress(query, model_name, segments=5):
95
  docx_file = "/tmp/generated_article.docx"
96
  doc = DocxDocument()
97
  doc.add_heading(query, level=1)
98
 
 
 
 
 
 
 
 
99
  all_text = []
100
  prompt = f"請依據下列主題生成段落:{query}\n\n每段約150-200字。"
101
 
102
  for i in range(int(segments)):
103
- paragraph = call_hf_inference(model_name, prompt)
104
  all_text.append(paragraph)
105
  doc.add_paragraph(paragraph)
106
  prompt = f"請接續上一段生成下一段:\n{paragraph}\n\n下一段:"
107
 
108
- yield "\n\n".join(all_text), None
109
 
110
  doc.save(docx_file)
111
- rate_info = get_hf_rate_limit()
112
- yield f"{rate_info}\n\n" + "\n\n".join(all_text), docx_file
113
 
114
  # -------------------------------
115
- # 7. Gradio 介面
116
  # -------------------------------
117
  with gr.Blocks() as demo:
118
- gr.Markdown("# 佛教經論 RAG 系統 (HF API)")
119
- gr.Markdown("使用 Hugging Face REST API + FAISS RAG,生成文章並提示 API 剩餘額度。")
120
 
121
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
122
  model_dropdown = gr.Dropdown(
123
- choices=[
124
- "gpt2",
125
- "facebook/bart-large-cnn",
126
- "bigscience/bloom-560m",
127
- "bigscience/bloomz-560m"
128
- ],
129
- value="bigscience/bloomz-560m", # 預設比較聽得懂指令
130
  label="選擇生成模型"
131
  )
132
  segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="段落數")
133
- output_text = gr.Textbox(label="生成文章 + API 剩餘次數")
134
  output_file = gr.File(label="下載 DOCX")
 
135
 
136
  btn = gr.Button("生成文章")
137
  btn.click(
138
  generate_article_progress,
139
  inputs=[query_input, model_dropdown, segments_input],
140
- outputs=[output_text, output_file]
141
  )
142
 
143
  # -------------------------------
144
- # 8. 啟動 Gradio(Hugging Face Space 適用)
145
  # -------------------------------
146
  if __name__ == "__main__":
147
  demo.launch()
 
 
1
  # -------------------------------
2
+ # 5. 生成文章(即時進度)
3
  # -------------------------------
4
def pick_model_auto(segments):
    """Pick a model name automatically from the requested paragraph count.

    Args:
        segments: Number of paragraphs the caller wants to generate.

    Returns:
        "Gemma-2B" for up to 3 paragraphs, "BTLM-3B-8K" for 4 to 6
        paragraphs, and "gpt-oss-20B" for anything above 6.
    """
    if segments <= 3:
        return "Gemma-2B"
    if segments <= 6:
        return "BTLM-3B-8K"
    return "gpt-oss-20B"
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
def generate_article_progress(query, model_name, segments=5):
    """Generate an article paragraph by paragraph, yielding progress as it goes.

    Args:
        query: Article topic; used as the document heading and in the first
            prompt.
        model_name: Model to use, or "Auto" to let ``pick_model_auto`` choose
            one from the paragraph count.
        segments: Number of paragraphs to generate (converted with ``int``).

    Yields:
        A 3-tuple ``(article_text, docx_path, model_label)``. After each
        paragraph ``docx_path`` is ``None``; the final yield carries the path
        of the saved DOCX file.
    """
    # Function-scope imports keep this block self-contained.
    import os
    import tempfile

    segment_count = int(segments)

    doc = DocxDocument()
    doc.add_heading(query, level=1)

    # Resolve the model: "Auto" delegates the choice to pick_model_auto.
    if model_name == "Auto":
        selected_model = pick_model_auto(segment_count)
    else:
        selected_model = model_name
    print(f"👉 使用模型: {selected_model}")
    model_label = f"本次使用模型:{selected_model}"

    all_text = []
    prompt = f"請依據下列主題生成段落:{query}\n\n每段約150-200字。"

    for _ in range(segment_count):
        paragraph = call_local_inference(selected_model, prompt)
        all_text.append(paragraph)
        doc.add_paragraph(paragraph)
        # The next prompt asks the model to continue from the paragraph
        # that was just produced.
        prompt = f"請接續上一段生成下一段:\n{paragraph}\n\n下一段:"

        # Intermediate update: text so far, no file yet.
        yield "\n\n".join(all_text), None, model_label

    # Save into a fresh per-call directory so that concurrent requests do not
    # overwrite each other's file; the download name stays the same.
    docx_file = os.path.join(
        tempfile.mkdtemp(prefix="article_"), "generated_article.docx"
    )
    doc.save(docx_file)
    yield "\n\n".join(all_text), docx_file, model_label
 
38
 
39
  # -------------------------------
40
+ # 6. Gradio 介面
41
  # -------------------------------
42
with gr.Blocks() as demo:
    # Page title and one-line description.
    gr.Markdown("# 佛教經論 RAG 系統 (本地推論 + Auto 模型選擇)")
    gr.Markdown("使用 Hugging Face Space + FAISS RAG,本地模型推論,不消耗 API 額度。")

    # Inputs: article topic, model choice, paragraph count.
    query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
    model_dropdown = gr.Dropdown(
        choices=list(MODEL_MAP.keys()),
        # Default to automatic selection.
        # NOTE(review): assumes MODEL_MAP (defined elsewhere) has an "Auto" key - confirm.
        value="Auto",
        label="選擇生成模型"
    )
    segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="段落數")

    # Outputs: generated text, downloadable DOCX, and the model actually used.
    output_text = gr.Textbox(label="生成文章")
    output_file = gr.File(label="下載 DOCX")
    model_used_text = gr.Textbox(label="實際使用模型", interactive=False)

    # The handler is a generator; each value it yields updates the three
    # outputs in this order.
    btn = gr.Button("生成文章")
    btn.click(
        generate_article_progress,
        inputs=[query_input, model_dropdown, segments_input],
        outputs=[output_text, output_file, model_used_text]
    )
63
 
64
  # -------------------------------
65
+ # 7. 啟動 Gradio
66
  # -------------------------------
67
if __name__ == "__main__":
    # The click handler is a generator. Gradio 3.x only runs generator
    # handlers when the queue is enabled; on newer releases the queue is
    # already on by default and this call is harmless.
    # NOTE(review): installed Gradio version is not visible here - confirm.
    demo.queue()
    demo.launch()