CHUNYU0505 committed on
Commit
94b2916
·
verified ·
1 Parent(s): dc31505

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -9
app.py CHANGED
@@ -10,12 +10,12 @@ from huggingface_hub import login, snapshot_download
10
  import gradio as gr
11
 
12
  # -------------------------------
13
- # 1. 模型清單
14
  # -------------------------------
15
  MODEL_MAP = {
16
  "Auto": None,
17
  "BTLM-3B-8K": "cerebras/btlm-3b-8k-base",
18
- "DistilGPT2": "distilgpt2",
19
  "BART-Base": "facebook/bart-base"
20
  }
21
 
@@ -29,7 +29,7 @@ if HF_TOKEN:
29
  # -------------------------------
30
  LOCAL_MODEL_DIRS = {}
31
  for name, repo in MODEL_MAP.items():
32
- if repo is None:
33
  continue
34
  try:
35
  local_dir = f"./models/{repo.split('/')[-1]}"
@@ -65,7 +65,7 @@ def get_pipeline(model_name):
65
  "text-generation",
66
  model=model,
67
  tokenizer=tokenizer,
68
- device=-1 # CPU 強制
69
  )
70
  _loaded_pipelines[model_name] = generator
71
  return _loaded_pipelines[model_name]
@@ -73,6 +73,11 @@ def get_pipeline(model_name):
73
  def call_local_inference(model_name, prompt, max_new_tokens=256):
74
  try:
75
  generator = get_pipeline(model_name)
 
 
 
 
 
76
  outputs = generator(
77
  prompt,
78
  max_new_tokens=max_new_tokens,
@@ -89,7 +94,7 @@ def call_local_inference(model_name, prompt, max_new_tokens=256):
89
  # -------------------------------
90
  def pick_model_auto(segments):
91
  if segments <= 3:
92
- return "DistilGPT2"
93
  elif segments <= 6:
94
  return "BTLM-3B-8K"
95
  else:
@@ -104,13 +109,15 @@ def generate_article_progress(query, model_name, segments=5):
104
  print(f"👉 使用模型: {selected_model}")
105
 
106
  all_text = []
107
- prompt = f"請依據下列主題生成段落:{query}\n\n每段約150-200字。"
108
 
109
  for i in range(segments):
 
 
110
  paragraph = call_local_inference(selected_model, prompt)
111
  all_text.append(paragraph)
112
  doc.add_paragraph(paragraph)
113
- prompt = f"請接續上一段生成下一段:\n{paragraph}\n\n下一段:"
114
  yield "\n\n".join(all_text), None, f"本次使用模型:{selected_model}"
115
 
116
  doc.save(docx_file)
@@ -120,8 +127,8 @@ def generate_article_progress(query, model_name, segments=5):
120
  # 5. Gradio 介面
121
  # -------------------------------
122
  with gr.Blocks() as demo:
123
- gr.Markdown("# 📺 電視弘法視頻生成文章RAG系統(CPU免費版))")
124
- gr.Markdown("支援 DistilGPT2 / BTLM-3B / BART-Base,Auto 模式會自動選擇。")
125
 
126
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
127
  model_dropdown = gr.Dropdown(choices=list(MODEL_MAP.keys()), value="Auto", label="選擇生成模型")
 
10
  import gradio as gr
11
 
12
  # -------------------------------
13
+ # 1. 模型清單(中文 & 英文)
14
  # -------------------------------
15
  MODEL_MAP = {
16
  "Auto": None,
17
  "BTLM-3B-8K": "cerebras/btlm-3b-8k-base",
18
+ "GPT2-Chinese": "uer/gpt2-chinese-cluecorpusmedium", # 中文 GPT2
19
  "BART-Base": "facebook/bart-base"
20
  }
21
 
 
29
  # -------------------------------
30
  LOCAL_MODEL_DIRS = {}
31
  for name, repo in MODEL_MAP.items():
32
+ if repo is None:
33
  continue
34
  try:
35
  local_dir = f"./models/{repo.split('/')[-1]}"
 
65
  "text-generation",
66
  model=model,
67
  tokenizer=tokenizer,
68
+ device=-1 # CPU
69
  )
70
  _loaded_pipelines[model_name] = generator
71
  return _loaded_pipelines[model_name]
 
73
  def call_local_inference(model_name, prompt, max_new_tokens=256):
74
  try:
75
  generator = get_pipeline(model_name)
76
+
77
+ # ✅ 強制中文模式:補上提示
78
+ if "中文" not in prompt and "Chinese" not in prompt:
79
+ prompt += "\n(請用中文回答)"
80
+
81
  outputs = generator(
82
  prompt,
83
  max_new_tokens=max_new_tokens,
 
94
  # -------------------------------
95
  def pick_model_auto(segments):
96
  if segments <= 3:
97
+ return "GPT2-Chinese" # 短文 → 中文 GPT2
98
  elif segments <= 6:
99
  return "BTLM-3B-8K"
100
  else:
 
109
  print(f"👉 使用模型: {selected_model}")
110
 
111
  all_text = []
112
+ base_prompt = f"請依據下列主題生成一篇中文文章,主題:{query}\n每段約150-200字。\n"
113
 
114
  for i in range(segments):
115
+ # ✅ 每段獨立生成
116
+ prompt = base_prompt + f"第{i+1}段:"
117
  paragraph = call_local_inference(selected_model, prompt)
118
  all_text.append(paragraph)
119
  doc.add_paragraph(paragraph)
120
+
121
  yield "\n\n".join(all_text), None, f"本次使用模型:{selected_model}"
122
 
123
  doc.save(docx_file)
 
127
  # 5. Gradio 介面
128
  # -------------------------------
129
  with gr.Blocks() as demo:
130
+ gr.Markdown("# 📺 電視弘法視頻生成文章RAG系統")
131
+ gr.Markdown("支援 GPT2-Chinese / BTLM-3B / BART-Base,Auto 模式會自動選擇,並強制中文輸出。")
132
 
133
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
134
  model_dropdown = gr.Dropdown(choices=list(MODEL_MAP.keys()), value="Auto", label="選擇生成模型")