CHUNYU0505 committed on
Commit
6d8dd36
·
verified ·
1 Parent(s): 9afdb2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -39
app.py CHANGED
@@ -10,10 +10,10 @@ from huggingface_hub import login, snapshot_download
10
  import gradio as gr
11
 
12
  # -------------------------------
13
- # 1. 模型設定(中文 T5 / Pegasus
14
  # -------------------------------
15
- PRIMARY_MODEL = "imxly/t5-pegasus-small" # 適合中文摘要/生成
16
- FALLBACK_MODEL = "uer/gpt2-chinese-cluecorpussmall" # 若 T5 無法下載就 fallback GPT2
17
 
18
  HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
19
  if HF_TOKEN:
@@ -31,9 +31,10 @@ def try_download_model(repo_id):
31
  return None
32
  return local_dir
33
 
 
34
  LOCAL_MODEL_DIR = try_download_model(PRIMARY_MODEL)
35
  if LOCAL_MODEL_DIR is None:
36
- print("⚠️ 切換到 fallback 模型:小型 GPT2-Chinese")
37
  LOCAL_MODEL_DIR = try_download_model(FALLBACK_MODEL)
38
  MODEL_NAME = FALLBACK_MODEL
39
  else:
@@ -42,28 +43,16 @@ else:
42
  print(f"👉 最終使用模型:{MODEL_NAME}")
43
 
44
  # -------------------------------
45
- # 2. pipeline 載入
46
  # -------------------------------
47
- tokenizer = AutoTokenizer.from_pretrained(
48
- LOCAL_MODEL_DIR,
49
- use_fast=False # 防止 sentencepiece 問題
50
- )
51
-
52
- # 判斷 GPU (CL3) 或 CPU
53
- device = 0 if torch.cuda.is_available() else -1
54
- print(f"💻 使用裝置:{'GPU' if device == 0 else 'CPU'}")
55
-
56
- try:
57
- model = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_MODEL_DIR)
58
- except:
59
- from transformers import AutoModelForCausalLM
60
- model = AutoModelForCausalLM.from_pretrained(LOCAL_MODEL_DIR)
61
 
62
  generator = pipeline(
63
- "text2text-generation" if "t5" in MODEL_NAME or "pegasus" in MODEL_NAME else "text-generation",
64
  model=model,
65
  tokenizer=tokenizer,
66
- device=device
67
  )
68
 
69
  def call_local_inference(prompt, max_new_tokens=256):
@@ -79,68 +68,69 @@ def call_local_inference(prompt, max_new_tokens=256):
79
  return f"(生成失敗:{e})"
80
 
81
  # -------------------------------
82
- # 3. FAISS 向量資料庫載入
83
  # -------------------------------
84
- DB_PATH = "./faiss_db"
85
  EMBEDDINGS_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
86
  embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
87
 
 
88
  if os.path.exists(os.path.join(DB_PATH, "index.faiss")):
89
  print("✅ 載入現有向量資料庫...")
90
  db = FAISS.load_local(DB_PATH, embeddings_model, allow_dangerous_deserialization=True)
91
  else:
92
- print("⚠️ 找不到向量資料庫,將建立空的 DB")
93
- db = FAISS.from_documents([], embeddings_model)
94
 
95
- retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
96
 
97
  # -------------------------------
98
- # 4. 文章生成(結合 RAG)
99
  # -------------------------------
100
- def generate_article_progress(query, segments=3):
101
  docx_file = "/tmp/generated_article.docx"
102
  doc = DocxDocument()
103
  doc.add_heading(query, level=1)
104
 
105
  all_text = []
106
 
107
- retrieved_docs = retriever.get_relevant_documents(query)
108
- context_texts = [d.page_content for d in retrieved_docs]
109
- context = "\n".join([f"{i+1}. {txt}" for i, txt in enumerate(context_texts[:3])])
 
 
110
 
111
  for i in range(segments):
112
  prompt = (
113
- f"以下是佛教經論的相關內容:\n{context}\n\n"
114
  f"請依據上面內容,寫一段約150-200字的中文文章,"
115
  f"主題:{query}。\n第{i+1}段:"
116
  )
117
  paragraph = call_local_inference(prompt)
118
  all_text.append(paragraph)
119
  doc.add_paragraph(paragraph)
120
-
121
- yield "\n\n".join(all_text), None, f"本次使用模型:{MODEL_NAME},裝置:{'GPU' if device == 0 else 'CPU'}"
122
 
123
  doc.save(docx_file)
124
- yield "\n\n".join(all_text), docx_file, f"本次使用模型:{MODEL_NAME},裝置:{'GPU' if device == 0 else 'CPU'}"
125
 
126
  # -------------------------------
127
  # 5. Gradio 介面
128
  # -------------------------------
129
  with gr.Blocks() as demo:
130
  gr.Markdown("# 📺 電視弘法視頻生成文章 RAG 系統")
131
- gr.Markdown("使用 Hugging Face 本地模型 + FAISS RAG,僅基於資料庫生成文章。")
132
 
133
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
134
- segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="段落數")
135
  output_text = gr.Textbox(label="生成文章")
136
  output_file = gr.File(label="下載 DOCX")
137
- status_info = gr.Label(label="狀態")
138
 
139
  btn = gr.Button("生成文章")
140
  btn.click(
141
  generate_article_progress,
142
  inputs=[query_input, segments_input],
143
- outputs=[output_text, output_file, status_info]
144
  )
145
 
146
  if __name__ == "__main__":
 
10
  import gradio as gr
11
 
12
  # -------------------------------
13
+ # 1. 模型設定(中文 T5 + fallback)
14
  # -------------------------------
15
+ PRIMARY_MODEL = "Langboat/mengzi-t5-base" # ✅ 帶 spiece.model
16
+ FALLBACK_MODEL = "uer/t5-small-chinese-cluecorpussmall"
17
 
18
  HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
19
  if HF_TOKEN:
 
31
  return None
32
  return local_dir
33
 
34
+ # 嘗試下載 Primary,失敗就換 Small
35
  LOCAL_MODEL_DIR = try_download_model(PRIMARY_MODEL)
36
  if LOCAL_MODEL_DIR is None:
37
+ print("⚠️ 切換到 fallback 模型:小型 T5-Chinese")
38
  LOCAL_MODEL_DIR = try_download_model(FALLBACK_MODEL)
39
  MODEL_NAME = FALLBACK_MODEL
40
  else:
 
43
  print(f"👉 最終使用模型:{MODEL_NAME}")
44
 
45
  # -------------------------------
46
+ # 2. pipeline 載入 (Seq2SeqLM for T5)
47
  # -------------------------------
48
+ tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_DIR)
49
+ model = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_MODEL_DIR)
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  generator = pipeline(
52
+ "text2text-generation",
53
  model=model,
54
  tokenizer=tokenizer,
55
+ device=-1 # CPU
56
  )
57
 
58
  def call_local_inference(prompt, max_new_tokens=256):
 
68
  return f"(生成失敗:{e})"
69
 
70
  # -------------------------------
71
+ # 3. 建立/載入向量資料庫
72
  # -------------------------------
 
73
  EMBEDDINGS_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
74
  embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
75
 
76
+ DB_PATH = "./faiss_db"
77
  if os.path.exists(os.path.join(DB_PATH, "index.faiss")):
78
  print("✅ 載入現有向量資料庫...")
79
  db = FAISS.load_local(DB_PATH, embeddings_model, allow_dangerous_deserialization=True)
80
  else:
81
+ print("⚠️ 沒有找到資料庫,請先建立 faiss_db")
82
+ db = None
83
 
84
+ retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3}) if db else None
85
 
86
  # -------------------------------
87
+ # 4. 文章生成(加入 RAG)
88
  # -------------------------------
89
+ def generate_article_progress(query, segments=5):
90
  docx_file = "/tmp/generated_article.docx"
91
  doc = DocxDocument()
92
  doc.add_heading(query, level=1)
93
 
94
  all_text = []
95
 
96
+ context = ""
97
+ if retriever:
98
+ retrieved_docs = retriever.get_relevant_documents(query)
99
+ context_texts = [d.page_content for d in retrieved_docs]
100
+ context = "\n".join([f"{i+1}. {txt}" for i, txt in enumerate(context_texts[:3])])
101
 
102
  for i in range(segments):
103
  prompt = (
104
+ f"以下是佛教經論的相關段落:\n{context}\n\n"
105
  f"請依據上面內容,寫一段約150-200字的中文文章,"
106
  f"主題:{query}。\n第{i+1}段:"
107
  )
108
  paragraph = call_local_inference(prompt)
109
  all_text.append(paragraph)
110
  doc.add_paragraph(paragraph)
111
+ yield "\n\n".join(all_text), None, f"本次使用模型:{MODEL_NAME}"
 
112
 
113
  doc.save(docx_file)
114
+ yield "\n\n".join(all_text), docx_file, f"本次使用模型:{MODEL_NAME}"
115
 
116
  # -------------------------------
117
  # 5. Gradio 介面
118
  # -------------------------------
119
  with gr.Blocks() as demo:
120
  gr.Markdown("# 📺 電視弘法視頻生成文章 RAG 系統")
121
+ gr.Markdown("使用 FAISS + 中文 T5 模型,基於資料庫內容生成文章。")
122
 
123
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
124
+ segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="段落數")
125
  output_text = gr.Textbox(label="生成文章")
126
  output_file = gr.File(label="下載 DOCX")
127
+ model_info = gr.Label(label="模型資訊")
128
 
129
  btn = gr.Button("生成文章")
130
  btn.click(
131
  generate_article_progress,
132
  inputs=[query_input, segments_input],
133
+ outputs=[output_text, output_file, model_info]
134
  )
135
 
136
  if __name__ == "__main__":