CHUNYU0505 committed on
Commit
06f5c87
·
verified ·
1 Parent(s): 9c1b3ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -52
app.py CHANGED
@@ -10,27 +10,10 @@ from huggingface_hub import login, snapshot_download
10
  import gradio as gr
11
 
12
  # -------------------------------
13
- # 0. 向量資料庫載入
14
- # -------------------------------
15
- EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
16
- embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
17
-
18
- DB_PATH = "./faiss_db"
19
- if os.path.exists(DB_PATH):
20
- print("✅ 載入現有向量資料庫...")
21
- db = FAISS.load_local(DB_PATH, embeddings_model, allow_dangerous_deserialization=True)
22
- else:
23
- raise ValueError("❌ 沒找到 faiss_db,請先建立向量資料庫")
24
-
25
- retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
26
-
27
- # -------------------------------
28
- # 1. 中文模型(T5 Pegasus)
29
  # -------------------------------
30
  MODEL_NAME = "imxly/t5-pegasus-small"
31
 
32
-
33
-
34
  HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
35
  if HF_TOKEN:
36
  login(token=HF_TOKEN)
@@ -41,7 +24,15 @@ if not os.path.exists(LOCAL_MODEL_DIR):
41
  print(f"⬇️ 嘗試下載模型 {MODEL_NAME} ...")
42
  snapshot_download(repo_id=MODEL_NAME, token=HF_TOKEN, local_dir=LOCAL_MODEL_DIR)
43
 
44
- tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_DIR)
 
 
 
 
 
 
 
 
45
  model = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_MODEL_DIR)
46
 
47
  generator = pipeline(
@@ -56,7 +47,7 @@ def call_local_inference(prompt, max_new_tokens=256):
56
  outputs = generator(
57
  prompt,
58
  max_new_tokens=max_new_tokens,
59
- do_sample=False, # 用摘要模型 → 不建議隨機取樣
60
  temperature=0.7
61
  )
62
  return outputs[0]["generated_text"]
@@ -64,67 +55,86 @@ def call_local_inference(prompt, max_new_tokens=256):
64
  return f"(生成失敗:{e})"
65
 
66
  # -------------------------------
67
- # 2. 基於 RAG 的文章生成
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  # -------------------------------
69
- def generate_article_rag_only(query, segments=3):
70
  docx_file = "/tmp/generated_article.docx"
71
  doc = DocxDocument()
72
  doc.add_heading(query, level=1)
73
- doc.save(docx_file)
74
 
75
  all_text = []
76
 
77
- # 🔍 RAG 檢索
78
  retrieved_docs = retriever.get_relevant_documents(query)
79
  context_texts = [d.page_content for d in retrieved_docs]
80
- full_context = "\n".join(context_texts)
81
 
82
- # 切分 context,避免太長
83
- splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
84
- chunks = splitter.split_text(full_context)
85
-
86
- for i, chunk in enumerate(chunks[:segments]):
87
- progress_text = f"⏳ 正在生成第 {i+1}/{segments} 段..."
88
  prompt = (
89
- f"以下是唯一可用的參考內容:\n{chunk}\n\n"
90
- f"請基於這些內容,寫一段約150-200字的中文文章,"
91
- f"主題:{query}。\n"
92
- f"⚠️ 僅能使用參考內容,不可加入外部知識。"
93
  )
 
94
  paragraph = call_local_inference(prompt)
95
  all_text.append(paragraph)
 
96
 
97
- # 即時寫入 DOCX
98
- doc = DocxDocument(docx_file)
99
- doc.add_paragraph(f"第{i+1}段:\n{paragraph}")
100
- doc.save(docx_file)
101
 
102
- yield "\n\n".join(all_text), None, f"本次使用模型:{MODEL_NAME}", full_context, progress_text
103
-
104
- final_progress = f"✅ 已完成全部 {segments} 段生成!"
105
- yield "\n\n".join(all_text), docx_file, f"本次使用模型:{MODEL_NAME}", full_context, final_progress
106
 
107
  # -------------------------------
108
- # 3. Gradio 介面
109
  # -------------------------------
110
  with gr.Blocks() as demo:
111
- gr.Markdown("# 📺 電視弘法視頻生成文章RAG系統")
112
- gr.Markdown("只基於 faiss_db 內容生成中文文章。")
113
 
114
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
115
- segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="段落數")
116
  output_text = gr.Textbox(label="生成文章")
117
  output_file = gr.File(label="下載 DOCX")
118
- model_used_text = gr.Textbox(label="實際使用模型", interactive=False)
119
- context_text = gr.Textbox(label="檢索到的內容", interactive=False, lines=6)
120
- progress_text = gr.Textbox(label="生成進度", interactive=False)
121
 
122
  btn = gr.Button("生成文章")
123
  btn.click(
124
- generate_article_rag_only,
125
  inputs=[query_input, segments_input],
126
- outputs=[output_text, output_file, model_used_text, context_text, progress_text]
127
  )
128
 
 
 
 
129
  if __name__ == "__main__":
130
  demo.launch()
 
10
  import gradio as gr
11
 
12
  # -------------------------------
13
+ # 1. 模型設定(專門中文 T5 Pegasus)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # -------------------------------
15
  MODEL_NAME = "imxly/t5-pegasus-small"
16
 
 
 
17
# Authenticate with the Hugging Face Hub only when a token is supplied via env.
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
 
24
  print(f"⬇️ 嘗試下載模型 {MODEL_NAME} ...")
25
  snapshot_download(repo_id=MODEL_NAME, token=HF_TOKEN, local_dir=LOCAL_MODEL_DIR)
26
 
27
print(f"👉 最終使用模型:{MODEL_NAME}")

# -------------------------------
# 2. Load tokenizer and model from the local snapshot
# -------------------------------
# use_fast=False: the slow (SentencePiece) tokenizer sidesteps tiktoken
# conversion errors seen with this checkpoint.
tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_DIR, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_MODEL_DIR)
37
 
38
  generator = pipeline(
 
47
  outputs = generator(
48
  prompt,
49
  max_new_tokens=max_new_tokens,
50
+ do_sample=True,
51
  temperature=0.7
52
  )
53
  return outputs[0]["generated_text"]
 
55
  return f"(生成失敗:{e})"
56
 
57
# -------------------------------
# 3. Build or load the FAISS vector database
# -------------------------------
TXT_FOLDER = "./out_texts"
DB_PATH = "./faiss_db"
os.makedirs(DB_PATH, exist_ok=True)
os.makedirs(TXT_FOLDER, exist_ok=True)

# Multilingual sentence-embedding model (covers Chinese text).
EMBEDDINGS_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)

if os.path.exists(os.path.join(DB_PATH, "index.faiss")):
    print("✅ 載入現有向量資料庫...")
    db = FAISS.load_local(DB_PATH, embeddings_model, allow_dangerous_deserialization=True)
else:
    print("⚠️ 沒有資料庫,建立新向量資料庫...")
    docs = []
    # sorted() gives a deterministic build order across runs/filesystems.
    txt_files = sorted(f for f in os.listdir(TXT_FOLDER) if f.endswith(".txt"))
    if not txt_files:
        # FAISS.from_documents raises an opaque error on an empty corpus;
        # fail early with an actionable message instead.
        raise ValueError(f"❌ {TXT_FOLDER} 沒有任何 .txt 檔案,無法建立向量資料庫")
    for filename in txt_files:
        with open(os.path.join(TXT_FOLDER, filename), "r", encoding="utf-8") as f:
            docs.append(Document(page_content=f.read(), metadata={"source": filename}))
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    split_docs = splitter.split_documents(docs)
    db = FAISS.from_documents(split_docs, embeddings_model)
    db.save_local(DB_PATH)

# Top-5 similarity retriever shared by the generation function below.
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
84
+
85
# -------------------------------
# 4. Article generation (RAG-assisted), streamed segment by segment
# -------------------------------
def generate_article_progress(query, segments=5):
    """Generate a multi-paragraph Chinese article on *query* grounded in RAG context.

    Parameters:
        query: article topic entered by the user.
        segments: number of paragraphs to generate (may arrive as a float
            from the Gradio slider; coerced to int).

    Yields:
        (text_so_far, docx_path_or_None, model_info) after each paragraph;
        the DOCX path is only supplied on the final yield.
    """
    docx_file = "/tmp/generated_article.docx"
    doc = DocxDocument()
    doc.add_heading(query, level=1)

    all_text = []

    # 🔍 Retrieve supporting passages; number the top 3 for the prompt.
    retrieved_docs = retriever.get_relevant_documents(query)
    context_texts = [d.page_content for d in retrieved_docs]
    context = "\n".join([f"{i+1}. {txt}" for i, txt in enumerate(context_texts[:3])])

    # int(): Gradio sliders can deliver floats; range() would raise TypeError.
    for i in range(int(segments)):
        prompt = (
            f"以下是佛教經論的相關段落:\n{context}\n\n"
            f"請依據上面內容,寫一段約150-200字的中文文章,"
            f"主題:{query}。\n第{i+1}段:"
        )
        paragraph = call_local_inference(prompt)
        all_text.append(paragraph)
        doc.add_paragraph(paragraph)
        # Persist after every paragraph so a mid-run failure still leaves a usable file.
        doc.save(docx_file)

        yield "\n\n".join(all_text), None, f"本次使用模型:{MODEL_NAME}"

    doc.save(docx_file)
    yield "\n\n".join(all_text), docx_file, f"本次使用模型:{MODEL_NAME}"
 
 
115
 
116
# -------------------------------
# 5. Gradio UI
# -------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 📺 電視弘法視頻生成文章 RAG 系統")
    gr.Markdown("使用 FAISS + 中文 T5 模型,根據資料庫生成中文文章。")

    # Inputs: topic text plus paragraph-count slider.
    query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
    segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="段落數")

    # Outputs: streamed article text, downloadable DOCX, model name.
    output_text = gr.Textbox(label="生成文章")
    output_file = gr.File(label="下載 DOCX")
    output_model = gr.Textbox(label="使用的模型")

    generate_btn = gr.Button("生成文章")
    generate_btn.click(
        fn=generate_article_progress,
        inputs=[query_input, segments_input],
        outputs=[output_text, output_file, output_model],
    )

# -------------------------------
# 6. Entry point
# -------------------------------
if __name__ == "__main__":
    demo.launch()