CHUNYU0505 commited on
Commit
8f7234f
·
verified ·
1 Parent(s): 49bfc77

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -38
app.py CHANGED
@@ -1,5 +1,4 @@
1
- # app.py
2
- import os
3
  from langchain.docstore.document import Document
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
  from langchain_community.vectorstores import FAISS
@@ -10,15 +9,16 @@ from huggingface_hub import login, snapshot_download
10
  import gradio as gr
11
 
12
  # -------------------------------
13
- # 1. 模型設定(專門中文 T5 Pegasus
14
  # -------------------------------
15
- MODEL_NAME = "imxly/t5-pegasus-small"
16
 
17
  HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
18
  if HF_TOKEN:
19
  login(token=HF_TOKEN)
20
  print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")
21
 
 
22
  LOCAL_MODEL_DIR = f"./models/{MODEL_NAME.split('/')[-1]}"
23
  if not os.path.exists(LOCAL_MODEL_DIR):
24
  print(f"⬇️ 嘗試下載模型 {MODEL_NAME} ...")
@@ -31,12 +31,12 @@ print(f"👉 最終使用模型:{MODEL_NAME}")
31
  # -------------------------------
32
  tokenizer = AutoTokenizer.from_pretrained(
33
  LOCAL_MODEL_DIR,
34
- use_fast=False # ✅ 避免 tiktoken 錯誤
35
  )
36
  model = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_MODEL_DIR)
37
 
38
  generator = pipeline(
39
- "text2text-generation",
40
  model=model,
41
  tokenizer=tokenizer,
42
  device=-1 # CPU
@@ -44,6 +44,9 @@ generator = pipeline(
44
 
45
  def call_local_inference(prompt, max_new_tokens=256):
46
  try:
 
 
 
47
  outputs = generator(
48
  prompt,
49
  max_new_tokens=max_new_tokens,
@@ -55,13 +58,9 @@ def call_local_inference(prompt, max_new_tokens=256):
55
  return f"(生成失敗:{e})"
56
 
57
  # -------------------------------
58
- # 3. 建立或載入向量資料庫
59
  # -------------------------------
60
- TXT_FOLDER = "./out_texts"
61
  DB_PATH = "./faiss_db"
62
- os.makedirs(DB_PATH, exist_ok=True)
63
- os.makedirs(TXT_FOLDER, exist_ok=True)
64
-
65
  EMBEDDINGS_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
66
  embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
67
 
@@ -69,41 +68,34 @@ if os.path.exists(os.path.join(DB_PATH, "index.faiss")):
69
  print("✅ 載入現有向量資料庫...")
70
  db = FAISS.load_local(DB_PATH, embeddings_model, allow_dangerous_deserialization=True)
71
  else:
72
- print("⚠️ 沒有資料庫,建立新向量資料庫...")
73
- docs = []
74
- txt_files = [f for f in os.listdir(TXT_FOLDER) if f.endswith(".txt")]
75
- for filename in txt_files:
76
- with open(os.path.join(TXT_FOLDER, filename), "r", encoding="utf-8") as f:
77
- docs.append(Document(page_content=f.read(), metadata={"source": filename}))
78
- splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
79
- split_docs = splitter.split_documents(docs)
80
- db = FAISS.from_documents(split_docs, embeddings_model)
81
- db.save_local(DB_PATH)
82
-
83
- retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
84
 
85
  # -------------------------------
86
  # 4. 文章生成(結合 RAG)
87
  # -------------------------------
88
- def generate_article_progress(query, segments=5):
89
  docx_file = "/tmp/generated_article.docx"
90
  doc = DocxDocument()
91
  doc.add_heading(query, level=1)
92
 
93
  all_text = []
94
 
95
- # 🔍 使用 RAG 檢索
96
- retrieved_docs = retriever.get_relevant_documents(query)
97
- context_texts = [d.page_content for d in retrieved_docs]
98
- context = "\n".join([f"{i+1}. {txt}" for i, txt in enumerate(context_texts[:3])])
 
 
99
 
100
  for i in range(segments):
101
  prompt = (
102
- f"以下是佛教經論的相關段落:\n{context}\n\n"
103
  f"請依據上面內容,寫一段約150-200字的中文文章,"
104
  f"主題:{query}。\n第{i+1}段:"
105
  )
106
-
107
  paragraph = call_local_inference(prompt)
108
  all_text.append(paragraph)
109
  doc.add_paragraph(paragraph)
@@ -118,23 +110,18 @@ def generate_article_progress(query, segments=5):
118
  # -------------------------------
119
  with gr.Blocks() as demo:
120
  gr.Markdown("# 📺 電視弘法視頻生成文章 RAG 系統")
121
- gr.Markdown("使用 FAISS + 中文 T5 模型,根據資料庫生成中文文章。")
122
-
123
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
124
- segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="段落數")
125
  output_text = gr.Textbox(label="生成文章")
126
  output_file = gr.File(label="下載 DOCX")
127
- output_model = gr.Textbox(label="使用的模型")
128
 
129
  btn = gr.Button("生成文章")
130
  btn.click(
131
  generate_article_progress,
132
  inputs=[query_input, segments_input],
133
- outputs=[output_text, output_file, output_model]
134
  )
135
 
136
- # -------------------------------
137
- # 6. 啟動
138
- # -------------------------------
139
  if __name__ == "__main__":
140
  demo.launch()
 
1
+ import os, torch
 
2
  from langchain.docstore.document import Document
3
  from langchain.text_splitter import RecursiveCharacterTextSplitter
4
  from langchain_community.vectorstores import FAISS
 
9
  import gradio as gr
10
 
11
  # -------------------------------
12
+ # 1. 模型設定(中文 T5)
13
  # -------------------------------
14
+ MODEL_NAME = "Langboat/mengzi-t5-base" # ✅ 換成穩定的中文 T5
15
 
16
  HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
17
  if HF_TOKEN:
18
  login(token=HF_TOKEN)
19
  print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")
20
 
21
+ # 嘗試下載模型
22
  LOCAL_MODEL_DIR = f"./models/{MODEL_NAME.split('/')[-1]}"
23
  if not os.path.exists(LOCAL_MODEL_DIR):
24
  print(f"⬇️ 嘗試下載模型 {MODEL_NAME} ...")
 
31
  # -------------------------------
32
  tokenizer = AutoTokenizer.from_pretrained(
33
  LOCAL_MODEL_DIR,
34
+ use_fast=False # ✅ 避免 tiktoken / fast tokenizer 問題
35
  )
36
  model = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_MODEL_DIR)
37
 
38
  generator = pipeline(
39
+ "text2text-generation", # ✅ Seq2Seq 用這個
40
  model=model,
41
  tokenizer=tokenizer,
42
  device=-1 # CPU
 
44
 
45
  def call_local_inference(prompt, max_new_tokens=256):
46
  try:
47
+ if "中文" not in prompt:
48
+ prompt += "\n(請用中文回答)"
49
+
50
  outputs = generator(
51
  prompt,
52
  max_new_tokens=max_new_tokens,
 
58
  return f"(生成失敗:{e})"
59
 
60
  # -------------------------------
61
+ # 3. RAG 部分:向量資料庫
62
  # -------------------------------
 
63
  DB_PATH = "./faiss_db"
 
 
 
64
  EMBEDDINGS_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
65
  embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
66
 
 
68
  print("✅ 載入現有向量資料庫...")
69
  db = FAISS.load_local(DB_PATH, embeddings_model, allow_dangerous_deserialization=True)
70
  else:
71
+ print("⚠️ 沒有找到資料庫,請先建立 faiss_db")
72
+ db = None
73
+
74
+ retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3}) if db else None
 
 
 
 
 
 
 
 
75
 
76
  # -------------------------------
77
  # 4. 文章生成(結合 RAG)
78
  # -------------------------------
79
+ def generate_article_progress(query, segments=3):
80
  docx_file = "/tmp/generated_article.docx"
81
  doc = DocxDocument()
82
  doc.add_heading(query, level=1)
83
 
84
  all_text = []
85
 
86
+ # 🔍 從資料庫檢索
87
+ context = ""
88
+ if retriever:
89
+ retrieved_docs = retriever.get_relevant_documents(query)
90
+ context_texts = [d.page_content for d in retrieved_docs]
91
+ context = "\n".join([f"{i+1}. {txt}" for i, txt in enumerate(context_texts[:3])])
92
 
93
  for i in range(segments):
94
  prompt = (
95
+ f"以下是佛教經論的相關內容:\n{context}\n\n"
96
  f"請依據上面內容,寫一段約150-200字的中文文章,"
97
  f"主題:{query}。\n第{i+1}段:"
98
  )
 
99
  paragraph = call_local_inference(prompt)
100
  all_text.append(paragraph)
101
  doc.add_paragraph(paragraph)
 
110
  # -------------------------------
111
  with gr.Blocks() as demo:
112
  gr.Markdown("# 📺 電視弘法視頻生成文章 RAG 系統")
 
 
113
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
114
+ segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="段落數")
115
  output_text = gr.Textbox(label="生成文章")
116
  output_file = gr.File(label="下載 DOCX")
117
+ model_info = gr.Textbox(label="模型資訊")
118
 
119
  btn = gr.Button("生成文章")
120
  btn.click(
121
  generate_article_progress,
122
  inputs=[query_input, segments_input],
123
+ outputs=[output_text, output_file, model_info]
124
  )
125
 
 
 
 
126
  if __name__ == "__main__":
127
  demo.launch()