File size: 4,715 Bytes
058eba2
8f7234f
d051231
 
 
 
 
052f25a
d051231
 
 
1740855
052f25a
d051231
052f25a
 
9c1b3ba
d051231
 
 
 
 
052f25a
 
 
d051231
06f5c87
 
 
052f25a
06f5c87
6d8dd36
052f25a
d051231
 
052f25a
06f5c87
 
 
 
052f25a
06f5c87
052f25a
06f5c87
052f25a
6d8dd36
8f7234f
052f25a
06f5c87
 
052f25a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0ba755
6d8dd36
d0ba755
 
 
 
6f0de2e
6d8dd36
 
 
 
 
80fe36a
06f5c87
6f0de2e
052f25a
 
 
 
 
6f0de2e
052f25a
a23ab36
132ef2d
06f5c87
052f25a
6d8dd36
6f0de2e
06f5c87
6d8dd36
80fe36a
d0ba755
052f25a
d0ba755
c6f8f84
06f5c87
052f25a
058eba2
c6f8f84
052f25a
fb13185
c6f8f84
052f25a
d0ba755
f90da5a
255d19f
06f5c87
a23ab36
6d8dd36
255d19f
f90da5a
d0ba755
a6c8097
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# app.py
import os, torch
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from docx import Document as DocxDocument
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from huggingface_hub import login, snapshot_download
import gradio as gr

# -------------------------------
# 1. 模型設定(專門中文,T5)
# -------------------------------
MODEL_NAME = "Langboat/mengzi-t5-base"  # Chinese T5 model light enough to run on CPU
LOCAL_MODEL_DIR = f"./models/{MODEL_NAME.split('/')[-1]}"

# Log in to the Hugging Face Hub only when a token is provided via the environment.
HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")

# Weight files whose presence (together with config.json) marks a usable local
# copy: single-file or sharded, safetensors or PyTorch format.
_MODEL_WEIGHT_FILES = (
    "model.safetensors",
    "model.safetensors.index.json",
    "pytorch_model.bin",
    "pytorch_model.bin.index.json",
)


def _local_model_is_complete(model_dir):
    """Return True if *model_dir* contains ``config.json`` and a weights file.

    Checking only that the directory exists would treat an interrupted
    download (empty or partial directory) as complete, after which loading
    the model below would fail on every start-up without ever retrying.
    """
    if not os.path.isfile(os.path.join(model_dir, "config.json")):
        return False
    return any(
        os.path.isfile(os.path.join(model_dir, name)) for name in _MODEL_WEIGHT_FILES
    )


if not _local_model_is_complete(LOCAL_MODEL_DIR):
    print(f"⬇️ 嘗試下載模型 {MODEL_NAME} ...")
    # NOTE(review): snapshot_download is expected to fill in missing files of an
    # existing local_dir rather than start over — confirm for the pinned
    # huggingface_hub version.
    snapshot_download(repo_id=MODEL_NAME, token=HF_TOKEN, local_dir=LOCAL_MODEL_DIR)

print(f"👉 最終使用模型:{MODEL_NAME}")

# -------------------------------
# 2. Load tokenizer + model (once, at import time; shared by every request)
# -------------------------------
tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_DIR)
model = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_MODEL_DIR, device_map="cpu")

# -------------------------------
# 3. 向量資料庫載入
# -------------------------------
EMBEDDINGS_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)

# Directory produced by FAISS.save_local(). load_local() reads BOTH files below,
# so both must exist; otherwise fall back to running without retrieval instead
# of crashing at start-up.
FAISS_DB_DIR = "./faiss_db"
_FAISS_DB_FILES = ("index.faiss", "index.pkl")

if all(os.path.exists(os.path.join(FAISS_DB_DIR, name)) for name in _FAISS_DB_FILES):
    print("✅ 載入現有向量資料庫...")
    # index.pkl is unpickled here (hence allow_dangerous_deserialization): only
    # load a database this app built itself, never one from an untrusted source.
    db = FAISS.load_local(
        FAISS_DB_DIR, embeddings_model, allow_dangerous_deserialization=True
    )
else:
    print("⚠️ 找不到向量資料庫,請先建立")
    db = None

# Similarity search returning the 5 closest chunks; None disables RAG downstream.
retriever = (
    db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
    if db is not None
    else None
)

# -------------------------------
# 4. 改良推理函數(避免重複亂碼)
# -------------------------------
def call_local_inference(prompt, max_new_tokens=256):
    """Run the module-level seq2seq model on *prompt* and return the decoded text.

    Decoding is deterministic beam search (4 beams, no sampling) with a
    repetition penalty, chosen to reduce repeated / garbled output.

    Args:
        prompt: Input text; tokenized with ``truncation=True``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded output with special tokens removed, or, if any step
        raises, the string ``"(生成失敗:<error>)"`` so the UI shows the
        failure instead of crashing.
    """
    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=False,          # deterministic decoding, no sampling
        num_beams=4,              # beam search
        repetition_penalty=1.2,   # penalize repeated tokens
        length_penalty=1.0,
        early_stopping=True,
    )
    try:
        encoded = tokenizer(prompt, return_tensors="pt", truncation=True)
        encoded = encoded.to(model.device)
        generated = model.generate(**encoded, **generation_kwargs)
        text = tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception as err:
        return f"(生成失敗:{err})"
    return text

# -------------------------------
# 5. 文章生成(加入 RAG)
# -------------------------------
def generate_article_progress(query, segments=5):
    """Generate an article on *query* paragraph by paragraph, using RAG context.

    This is a generator used as a Gradio streaming handler. After each
    paragraph it yields ``(article_so_far, None, model_message)``. Once all
    paragraphs are written it saves them to a DOCX file and yields
    ``(article, docx_path, model_message)``.

    Args:
        query: Article topic. Used as the DOCX heading and, when a FAISS
            retriever was loaded, as the retrieval query.
        segments: Number of paragraphs to generate. Coerced with ``int()``
            because a UI slider may deliver a float, which ``range`` rejects.
    """
    import tempfile  # function-scope import keeps this block self-contained

    model_message = f"本次使用模型:{MODEL_NAME}"
    doc = DocxDocument()
    doc.add_heading(query, level=1)
    paragraphs = []

    context = ""
    if retriever is not None:
        # `invoke` is the Runnable API; `get_relevant_documents` is deprecated.
        retrieved_docs = retriever.invoke(query)
        # Only the top 3 retrieved chunks go into the prompt (presumably to keep
        # it short for the T5 encoder), numbered from 1.
        context = "\n".join(
            f"{rank}. {d.page_content}"
            for rank, d in enumerate(retrieved_docs[:3], start=1)
        )

    for i in range(int(segments)):
        # NOTE(review): call_local_inference truncates the prompt; with the
        # tokenizer's default right-side truncation a long context could drop
        # the trailing "第N段" cue — consider capping the context length.
        prompt = (
            f"請基於以下資料,撰寫一段中文文章:\n"
            f"主題:{query}\n"
            f"要求:字數約150~200字,內容要有完整句子,不要重複詞語。\n\n"
            f"參考資料:\n{context}\n\n"
            f"第{i+1}段:"
        )

        paragraph = call_local_inference(prompt)
        paragraphs.append(paragraph)
        doc.add_paragraph(paragraph)

        # Stream the partial article; no file is available until the end.
        yield "\n\n".join(paragraphs), None, model_message

    # A fresh directory per call (instead of one fixed /tmp path shared by every
    # run) avoids collisions between overlapping runs and works on systems
    # without /tmp; the download keeps the same file name. These small
    # directories are not cleaned up here.
    docx_file = os.path.join(tempfile.mkdtemp(), "generated_article.docx")
    doc.save(docx_file)
    yield "\n\n".join(paragraphs), docx_file, model_message

# -------------------------------
# 6. Gradio 介面
# -------------------------------
with gr.Blocks() as demo:
    # --- Header ---------------------------------------------------------
    gr.Markdown("# 📺 電視弘法視頻生成文章 RAG 系統")
    gr.Markdown("基於向量資料庫 + 中文 T5 模型,自動生成主題文章")

    # --- Inputs: topic text and paragraph count (1-10, default 3) --------
    query_input = gr.Textbox(
        label="文章主題",
        placeholder="請輸入文章主題",
        lines=2,
    )
    segments_input = gr.Slider(
        label="段落數",
        minimum=1,
        maximum=10,
        value=3,
        step=1,
    )

    # --- Outputs: streamed article, DOCX download, model name -----------
    output_text = gr.Textbox(label="生成文章")
    output_file = gr.File(label="下載 DOCX")
    model_info = gr.Textbox(label="模型資訊")

    # The handler is a generator; each yielded 3-tuple is routed to the
    # outputs below, in this order.
    btn = gr.Button("生成文章")
    btn.click(
        fn=generate_article_progress,
        inputs=[query_input, segments_input],
        outputs=[output_text, output_file, model_info],
    )

if __name__ == "__main__":
    # Start the Gradio server only when run as a script.
    demo.launch()