RAG_Test_System / app.py
CHUNYU0505's picture
Update app.py
2aa3d8b verified
raw
history blame
3.56 kB
# -------------------------------
# 4. 本地推論模型設定
# -------------------------------
# Maps the display name offered in the UI to the Hugging Face repository id
# that is handed to the text-generation pipeline.
MODEL_MAP = {
    # No fixed repository: "Auto" is resolved to one of the entries below
    # before a model is loaded.
    "Auto": None,
    "Gemma-2B": "google/gemma-2b",
    "Gemma-7B": "google/gemma-7b",
    # NOTE(review): the published BTLM repositories appear to be
    # "cerebras/btlm-3b-8k-base" / "-chat" and to require
    # trust_remote_code=True -- confirm before relying on this entry.
    "BTLM-3B-8K": "cerebras/btlm-3b-8k",
    # gpt-oss is published under the "openai" organisation; the
    # "openai-community" organisation does not host it.
    "gpt-oss-20B": "openai/gpt-oss-20b",
}
# Cache of already-built pipelines, keyed by display name, so a model is not
# rebuilt on every call.
_loaded_pipelines = {}


def get_pipeline(model_name):
    """Return the text-generation pipeline for *model_name*, building it on first use.

    Args:
        model_name: A key of ``MODEL_MAP``; the mapped value is passed to
            ``pipeline`` as both the model and the tokenizer.

    Returns:
        The pipeline object stored in the module-level cache for this name.

    Raises:
        KeyError: If *model_name* is not cached and is not a key of
            ``MODEL_MAP``.
    """
    # Fast path: this model has already been loaded.
    if model_name in _loaded_pipelines:
        return _loaded_pipelines[model_name]

    print(f"🔄 正在載入模型 {model_name} ...")
    repo_id = MODEL_MAP[model_name]
    built = pipeline(
        "text-generation",
        model=repo_id,
        tokenizer=repo_id,
        device_map="auto",
    )
    _loaded_pipelines[model_name] = built
    return built
def call_local_inference(model_name, prompt, max_new_tokens=512):
    """Generate text for *prompt* with a locally loaded model.

    Args:
        model_name: Display name passed to ``get_pipeline``.
        prompt: Text the model should continue.
        max_new_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The newly generated text only (the prompt is not echoed back). If
        anything fails, a string of the form ``"(生成失敗:<error>)"`` is
        returned instead of raising.
    """
    try:
        generator = get_pipeline(model_name)
        outputs = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            # By default the pipeline returns prompt + completion. Callers
            # treat the return value as the generated paragraph, so the
            # instruction text must not be included.
            return_full_text=False,
        )
        return outputs[0]["generated_text"]
    except Exception as e:
        # Deliberate best-effort: report the failure as text so the caller's
        # generation loop keeps running instead of crashing.
        return f"(生成失敗:{e})"
# -------------------------------
# 5. 生成文章(即時進度)
# -------------------------------
def pick_model_auto(segments):
    """Choose a model display name from the requested number of paragraphs.

    Up to 3 paragraphs selects "Gemma-2B", up to 6 selects "BTLM-3B-8K",
    and anything larger selects "gpt-oss-20B".
    """
    # (inclusive upper bound, model) pairs, checked from smallest to largest.
    thresholds = ((3, "Gemma-2B"), (6, "BTLM-3B-8K"))
    for upper_bound, choice in thresholds:
        if segments <= upper_bound:
            return choice
    return "gpt-oss-20B"
def generate_article_progress(query, model_name, segments=5):
    """Generate an article paragraph by paragraph, yielding progress after each one.

    Args:
        query: Article topic; used as the document heading and in the first
            prompt.
        model_name: Display name of the model to use. "Auto" selects a model
            from the paragraph count via ``pick_model_auto``.
        segments: Number of paragraphs to generate; converted with ``int()``.

    Yields:
        ``(article_text, docx_path, status)`` tuples. ``article_text`` is all
        paragraphs so far joined by blank lines. ``docx_path`` is ``None``
        while paragraphs are being generated; the final tuple carries the
        path of the saved DOCX file. ``status`` names the model in use.
    """
    import os
    import tempfile

    segment_count = int(segments)

    doc = DocxDocument()
    doc.add_heading(query, level=1)

    # Resolve the "Auto" choice to a concrete model.
    if model_name == "Auto":
        selected_model = pick_model_auto(segment_count)
    else:
        selected_model = model_name
    print(f"👉 使用模型: {selected_model}")
    status = f"本次使用模型:{selected_model}"

    all_text = []
    prompt = f"請依據下列主題生成段落:{query}\n\n每段約150-200字。"
    for _ in range(segment_count):
        paragraph = call_local_inference(selected_model, prompt)
        all_text.append(paragraph)
        doc.add_paragraph(paragraph)
        # Each following paragraph is asked to continue from the previous one.
        prompt = f"請接續上一段生成下一段:\n{paragraph}\n\n下一段:"
        yield "\n\n".join(all_text), None, status

    # Save into a directory private to this call: a single fixed path would be
    # shared by concurrent requests, letting one overwrite another's file.
    # The file name itself is kept unchanged.
    docx_file = os.path.join(tempfile.mkdtemp(), "generated_article.docx")
    doc.save(docx_file)
    yield "\n\n".join(all_text), docx_file, status
# -------------------------------
# 6. Gradio 介面
# -------------------------------
with gr.Blocks() as demo:
    # Page title and short description.
    gr.Markdown("# 佛教經論 RAG 系統 (本地推論 + Auto 模型選擇)")
    gr.Markdown("使用 Hugging Face Space + FAISS RAG,本地模型推論,不消耗 API 額度。")

    # Inputs: topic, model choice (automatic selection by default) and
    # paragraph count.
    query_input = gr.Textbox(label="文章主題", placeholder="請輸入文章主題", lines=2)
    model_dropdown = gr.Dropdown(
        label="選擇生成模型",
        choices=list(MODEL_MAP),
        value="Auto",
    )
    segments_input = gr.Slider(label="段落數", minimum=1, maximum=10, step=1, value=5)

    # Outputs: article text, downloadable DOCX and the model actually used.
    output_text = gr.Textbox(label="生成文章")
    output_file = gr.File(label="下載 DOCX")
    model_used_text = gr.Textbox(label="實際使用模型", interactive=False)

    btn = gr.Button("生成文章")
    btn.click(
        fn=generate_article_progress,
        inputs=[query_input, model_dropdown, segments_input],
        outputs=[output_text, output_file, model_used_text],
    )

# -------------------------------
# 7. Launch Gradio
# -------------------------------
if __name__ == "__main__":
    demo.launch()