# Imports used below (assumed to live in sections 1-3 of the full file;
# repeated here so this portion runs on its own)
import gradio as gr
from docx import Document as DocxDocument
from transformers import pipeline

# -------------------------------
# 4. Local inference model setup
# -------------------------------
MODEL_MAP = {
    "Auto": None,  # pick automatically based on paragraph count
    "Gemma-2B": "google/gemma-2b",
    "Gemma-7B": "google/gemma-7b",
    "BTLM-3B-8K": "cerebras/btlm-3b-8k-base",  # the published repo id carries the "-base" suffix
    "gpt-oss-20B": "openai/gpt-oss-20b"  # hosted under "openai", not "openai-community"
}
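
# Note: the Gemma repos above are gated on Hugging Face and require accepting
# Google's license (and a logged-in HF token) before they can be downloaded.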

# Cache pipelines so each model is only loaded once
_loaded_pipelines = {}

def get_pipeline(model_name):
    """Return a cached text-generation pipeline, loading the model on first use."""
    if model_name not in _loaded_pipelines:
        print(f"🔄 Loading model {model_name} ...")
        model_id = MODEL_MAP[model_name]
        generator = pipeline(
            "text-generation",
            model=model_id,
            tokenizer=model_id,
            device_map="auto",       # requires the `accelerate` package
            trust_remote_code=True,  # BTLM ships custom modeling code
        )
        _loaded_pipelines[model_name] = generator
    return _loaded_pipelines[model_name]

def call_local_inference(model_name, prompt, max_new_tokens=512):
    try:
        generator = get_pipeline(model_name)
        outputs = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            return_full_text=False,  # return only the new text, not the echoed prompt
        )
        return outputs[0]["generated_text"]
    except Exception as e:
        return f"(Generation failed: {e})"
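
# A minimal smoke test (sketch): assumes this machine can actually load
# Gemma-2B; uncomment to try the pipeline outside the Gradio app.
# print(call_local_inference("Gemma-2B", "Briefly explain dependent origination:", max_new_tokens=64))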

# -------------------------------
# 5. Article generation (live progress)
# -------------------------------
def pick_model_auto(segments):
    """Pick a model automatically based on the number of paragraphs."""
    if segments <= 3:
        return "Gemma-2B"
    elif segments <= 6:
        return "BTLM-3B-8K"
    else:
        return "gpt-oss-20B"

def generate_article_progress(query, model_name, segments=5):
    docx_file = "/tmp/generated_article.docx"
    doc = DocxDocument()
    doc.add_heading(query, level=1)

    # Resolve "Auto" to a concrete model
    if model_name == "Auto":
        selected_model = pick_model_auto(int(segments))
    else:
        selected_model = model_name
    print(f"👉 Using model: {selected_model}")

    all_text = []
    prompt = f"Generate a paragraph on the following topic: {query}\n\nEach paragraph should be about 150-200 characters."

    for i in range(int(segments)):
        paragraph = call_local_inference(selected_model, prompt)
        all_text.append(paragraph)
        doc.add_paragraph(paragraph)
        # Chain prompts so each paragraph continues from the previous one
        prompt = f"Continue from the previous paragraph and generate the next one:\n{paragraph}\n\nNext paragraph:"

        # Partial yield: lets Gradio stream progress to the UI
        yield "\n\n".join(all_text), None, f"Model used: {selected_model}"

    doc.save(docx_file)
    yield "\n\n".join(all_text), docx_file, f"Model used: {selected_model}"

# -------------------------------
# 6. Gradio interface
# -------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Buddhist Scripture RAG System (local inference + Auto model selection)")
    gr.Markdown("Runs on a Hugging Face Space with FAISS RAG; inference uses local models, so no API quota is consumed.")

    query_input = gr.Textbox(lines=2, placeholder="Enter an article topic", label="Article topic")
    model_dropdown = gr.Dropdown(
        choices=list(MODEL_MAP.keys()),
        value="Auto",   # default to automatic model selection
        label="Generation model"
    )
    segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of paragraphs")
    output_text = gr.Textbox(label="Generated article")
    output_file = gr.File(label="Download DOCX")
    model_used_text = gr.Textbox(label="Model actually used", interactive=False)

    btn = gr.Button("Generate article")
    btn.click(
        generate_article_progress,
        inputs=[query_input, model_dropdown, segments_input],
        outputs=[output_text, output_file, model_used_text]
    )

# -------------------------------
# 7. Launch Gradio
# -------------------------------
if __name__ == "__main__":
    demo.launch()