File size: 5,628 Bytes
d0ba755
 
 
 
 
 
 
4d5736b
 
d0ba755
 
 
 
 
 
 
 
 
4d5736b
d0ba755
 
 
299f87b
 
 
d0ba755
 
 
 
 
 
 
 
 
 
 
 
 
f90da5a
 
d0ba755
 
 
c6f8f84
d0ba755
 
 
 
 
 
 
 
4d5736b
d0ba755
4d5736b
d0ba755
3bcec19
 
4d5736b
 
 
 
 
fc6e44c
4d5736b
 
 
 
 
fc6e44c
4d5736b
 
fc6e44c
 
4d5736b
 
 
 
 
 
 
d0ba755
4d5736b
d0ba755
 
4d5736b
 
c6f8f84
4d5736b
d0ba755
 
4d5736b
d0ba755
3bcec19
d0ba755
 
 
 
 
 
80fe36a
d0ba755
3bcec19
132ef2d
 
d0ba755
 
132ef2d
80fe36a
d0ba755
132ef2d
4d5736b
000034c
d0ba755
4d5736b
d0ba755
c6f8f84
 
4d5736b
f90da5a
c6f8f84
3bcec19
 
 
fc6e44c
 
 
3bcec19
fc6e44c
3bcec19
 
c6f8f84
 
 
d0ba755
f90da5a
255d19f
4d5736b
3bcec19
255d19f
 
f90da5a
 
4d5736b
f90da5a
d0ba755
a6c8097
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# app.py
# -------------------------------
# 1. 套件載入
# -------------------------------
import os, glob, requests
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from docx import Document as DocxDocument
import gradio as gr

# -------------------------------
# 2. 環境變數與資料路徑
# -------------------------------
# Local folders: TXT sources to index, and where the FAISS index is persisted.
TXT_FOLDER = "./out_texts"
DB_PATH = "./faiss_db"
for _folder in (DB_PATH, TXT_FOLDER):
    os.makedirs(_folder, exist_ok=True)

# Token for the Hugging Face REST calls made later in this file. A missing or
# empty value is fatal, so fail fast at startup with setup instructions.
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if HF_TOKEN in (None, ""):
    raise ValueError(
        "請在 Hugging Face Space 的 Settings → Repository secrets 設定 HUGGINGFACEHUB_API_TOKEN"
    )

# -------------------------------
# 3. 建立或載入向量資料庫
# -------------------------------
EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)

if os.path.exists(os.path.join(DB_PATH, "index.faiss")):
    print("載入現有向量資料庫...")
    # load_local unpickles the stored docstore, hence the explicit opt-in.
    # That is only acceptable because DB_PATH is written by this app itself;
    # never point it at an index from an untrusted source.
    db = FAISS.load_local(DB_PATH, embeddings_model, allow_dangerous_deserialization=True)
else:
    print("沒有資料庫,開始建立新向量資料庫...")
    # Sorted so chunk order (and therefore index ids) is reproducible across rebuilds.
    txt_files = sorted(glob.glob(os.path.join(TXT_FOLDER, "*.txt")))
    if not txt_files:
        print("注意:TXT 資料夾中沒有任何文字檔,向量資料庫將為空。")
    docs = []
    for filepath in txt_files:
        with open(filepath, "r", encoding="utf-8") as f:
            docs.append(Document(page_content=f.read(), metadata={"source": os.path.basename(filepath)}))
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    split_docs = splitter.split_documents(docs)
    if split_docs:
        db = FAISS.from_documents(split_docs, embeddings_model)
        db.save_local(DB_PATH)
    else:
        # FAISS.from_documents cannot build an index from zero documents
        # (LangChain sizes the index from the first embedding), which used to
        # crash startup right after the "empty DB" notice above. Nothing is
        # saved either, so the index is built on a later start once TXT files
        # with content have been added.
        db = None

# NOTE(review): `retriever` is not referenced anywhere else in this file, so
# article generation is currently not retrieval-augmented. It is None when
# there was no text to index.
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5}) if db is not None else None

# -------------------------------
# 4. 定義 REST API 呼叫函數
# -------------------------------
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}

def call_hf_inference(model_name, prompt, max_new_tokens=512, timeout=180):
    """Send *prompt* to a model on the Hugging Face Inference REST API.

    Args:
        model_name: Hub model id, e.g. "bigscience/bloomz-560m".
        prompt: Text sent as the request's "inputs".
        max_new_tokens: Passed through as the "max_new_tokens" parameter.
        timeout: Seconds ``requests`` waits for the call (connect and read).

    Returns:
        The model's text on success: the "generated_text" field, or the
        "summary_text" field that summarization models return instead.
        On any failure, a display message starting with "(生成失敗" is returned
        rather than raised, so the UI can show it directly. An unrecognized
        success payload is returned as ``str(data)``.
    """
    api_url = f"https://api-inference.huggingface.co/models/{model_name}"
    payload = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": max_new_tokens}
    }
    try:
        # Long default timeout: a cold model may need to load before answering.
        response = requests.post(api_url, headers=HEADERS, json=payload, timeout=timeout)
        try:
            data = response.json()
        except ValueError:
            # Non-JSON body: report an HTTP error status if there is one,
            # otherwise let the decode error reach the handler below.
            response.raise_for_status()
            raise
        # Inspect a JSON {"error": ...} body before the status code, so the
        # API's own message (for example, a model that is still loading)
        # reaches the user instead of a generic HTTP error.
        if isinstance(data, dict) and "error" in data:
            return f"(生成失敗:{data['error']},請嘗試換一個模型)"
        response.raise_for_status()
    except requests.exceptions.Timeout:
        # Base class of both ConnectTimeout and ReadTimeout.
        return "(生成失敗:等待超時,請嘗試換小一點的模型或增加 timeout 秒數)"
    except Exception as e:
        # Deliberate catch-all boundary: failures become display text.
        return f"(生成失敗:{e})"

    # Guard against an empty list or non-dict items before indexing data[0].
    if isinstance(data, list) and data and isinstance(data[0], dict):
        for key in ("generated_text", "summary_text"):
            if key in data[0]:
                return data[0][key]
    # Unrecognized payload shape: show it raw rather than dropping it.
    return str(data)

# -------------------------------
# 5. 查詢 API 剩餘額度
# -------------------------------
def get_hf_rate_limit(timeout=10):
    """Best-effort lookup of the remaining Hugging Face API quota.

    Queries ``https://huggingface.co/api/whoami`` with the module's auth
    headers and reads ``rate_limit.remaining`` from the JSON reply.
    NOTE(review): it is not confirmed that this endpoint returns a
    ``rate_limit`` field; when it is absent the message reports "未知".

    Args:
        timeout: Seconds to wait for the Hub. Without a timeout ``requests``
            can block indefinitely and stall the request that asked for it.

    Returns:
        A display string. It never raises: any failure yields a fixed
        fallback message, because the quota line is purely informational.
    """
    try:
        r = requests.get("https://huggingface.co/api/whoami", headers=HEADERS, timeout=timeout)
        r.raise_for_status()
        data = r.json()
        # `or {}` also tolerates an explicit null "rate_limit" value.
        remaining = (data.get("rate_limit") or {}).get("remaining", "未知")
        return f"本小時剩餘 API 次數:約 {remaining}"
    except Exception:
        # Informational only: a failed quota lookup must not break the caller.
        return "無法取得 API 速率資訊"

# -------------------------------
# 6. 生成文章(即時進度)
# -------------------------------
def generate_article_progress(query, model_name, segments=5):
    """Generate an article paragraph by paragraph, streaming progress to the UI.

    The first prompt asks for a paragraph on *query*. Every later prompt asks
    the model to continue from the previous paragraph. Each paragraph is one
    ``call_hf_inference`` request.

    NOTE(review): despite the app's RAG title, no context from the
    module-level ``retriever`` is added to these prompts.

    Args:
        query: Article topic; also used as the DOCX heading.
        model_name: Hub model id forwarded to ``call_hf_inference``.
        segments: Maximum number of paragraphs. It is coerced with ``int``,
            presumably because the Gradio slider may deliver a float.

    Yields:
        ``(text, file)`` pairs for the two UI outputs.
        - After each paragraph: the text so far and ``None``.
        - Finally: the text prefixed with the API quota line, plus the path
          of a freshly written DOCX file.
        - For an empty topic: a single prompt message and ``None``.
    """
    import tempfile

    if not query or not query.strip():
        yield "請輸入文章主題", None
        return

    doc = DocxDocument()
    doc.add_heading(query, level=1)

    all_text = []
    prompt = f"請依據下列主題生成段落:{query}\n\n每段約150-200字。"

    for _ in range(int(segments)):
        paragraph = call_hf_inference(model_name, prompt)
        # Text-generation endpoints echo the prompt by default
        # (return_full_text). Drop it so the instruction text neither lands in
        # the article nor snowballs into the next prompt.
        if paragraph.startswith(prompt):
            paragraph = paragraph[len(prompt):].strip()
        all_text.append(paragraph)
        doc.add_paragraph(paragraph)
        yield "\n\n".join(all_text), None

        if paragraph.startswith("(生成失敗"):
            # call_hf_inference reports failures as text. Using that as the
            # next prompt would only burn more API calls on garbage.
            break
        prompt = f"請接續上一段生成下一段:\n{paragraph}\n\n下一段:"

    # One unique file per call: a fixed path would let concurrent users of the
    # Space overwrite (and download) each other's article.
    with tempfile.NamedTemporaryFile(prefix="generated_article_", suffix=".docx", delete=False) as tmp:
        docx_file = tmp.name
    doc.save(docx_file)

    rate_info = get_hf_rate_limit()
    yield f"{rate_info}\n\n" + "\n\n".join(all_text), docx_file

# -------------------------------
# 7. Gradio 介面
# -------------------------------
with gr.Blocks() as demo:
    # Header: title and a one-line description of the app.
    gr.Markdown(value="# 佛教經論 RAG 系統 (HF API)")
    gr.Markdown(value="使用 Hugging Face REST API + FAISS RAG,生成文章並提示 API 剩餘額度。")

    # Inputs: topic, model id, and number of paragraphs.
    query_input = gr.Textbox(label="文章主題", placeholder="請輸入文章主題", lines=2)
    model_dropdown = gr.Dropdown(
        label="選擇生成模型",
        choices=["gpt2", "facebook/bart-large-cnn", "bigscience/bloom-560m", "bigscience/bloomz-560m"],
        # Default to bloomz: it follows instructions better than the others.
        value="bigscience/bloomz-560m",
    )
    segments_input = gr.Slider(label="段落數", minimum=1, maximum=10, value=5, step=1)

    # Outputs: article text (with the quota line) and the DOCX download.
    output_text = gr.Textbox(label="生成文章 + API 剩餘次數")
    output_file = gr.File(label="下載 DOCX")

    btn = gr.Button(value="生成文章")
    # The handler yields (text, file) pairs matching the two outputs.
    btn.click(
        fn=generate_article_progress,
        inputs=[query_input, model_dropdown, segments_input],
        outputs=[output_text, output_file],
    )

# -------------------------------
# 8. Launch Gradio (suitable for Hugging Face Spaces)
# -------------------------------
if __name__ == "__main__":
    demo.launch()