CHUNYU0505 committed
Commit a23ab36 · verified
1 parent: 94b2916

Update app.py: drop the MODEL_MAP / Auto model selection and pin generation to uer/gpt2-chinese-cluecorpusmedium

Files changed (1)
app.py +33 -74
app.py CHANGED
@@ -10,14 +10,9 @@ from huggingface_hub import login, snapshot_download
 import gradio as gr
 
 # -------------------------------
-# 1. Model list (Chinese & English)
+# 1. Model list (Chinese GPT2 only)
 # -------------------------------
-MODEL_MAP = {
-    "Auto": None,
-    "BTLM-3B-8K": "cerebras/btlm-3b-8k-base",
-    "GPT2-Chinese": "uer/gpt2-chinese-cluecorpusmedium",  # Chinese GPT2
-    "BART-Base": "facebook/bart-base"
-}
+MODEL_NAME = "uer/gpt2-chinese-cluecorpusmedium"
 
 HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
 if HF_TOKEN:
@@ -25,57 +20,35 @@ if HF_TOKEN:
     print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")
 
 # -------------------------------
-# 2. Pre-download the models
+# 2. Download the model
 # -------------------------------
-LOCAL_MODEL_DIRS = {}
-for name, repo in MODEL_MAP.items():
-    if repo is None:
-        continue
-    try:
-        local_dir = f"./models/{repo.split('/')[-1]}"
-        if not os.path.exists(local_dir):
-            print(f"⬇️ 正在下載模型 {repo} ...")
-            snapshot_download(repo_id=repo, token=HF_TOKEN, local_dir=local_dir)
-        LOCAL_MODEL_DIRS[name] = local_dir
-    except Exception as e:
-        print(f"⚠️ 模型 {repo} 無法下載: {e}")
+LOCAL_MODEL_DIR = f"./models/{MODEL_NAME.split('/')[-1]}"
+if not os.path.exists(LOCAL_MODEL_DIR):
+    print(f"⬇️ 正在下載模型 {MODEL_NAME} ...")
+    snapshot_download(repo_id=MODEL_NAME, token=HF_TOKEN, local_dir=LOCAL_MODEL_DIR)
 
 # -------------------------------
 # 3. Load the pipeline
 # -------------------------------
-_loaded_pipelines = {}
-
-def get_pipeline(model_name):
-    if model_name not in _loaded_pipelines:
-        local_path = LOCAL_MODEL_DIRS.get(model_name)
-        print(f"🔄 正在載入模型 {model_name} from {local_path}")
-
-        if model_name == "BTLM-3B-8K":
-            tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=True)
-            model = AutoModelForCausalLM.from_pretrained(local_path, trust_remote_code=True)
-        else:
-            tokenizer = AutoTokenizer.from_pretrained(local_path)
-            model = AutoModelForCausalLM.from_pretrained(local_path)
-
-        # Fix the missing pad_token issue
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-
-        generator = pipeline(
-            "text-generation",
-            model=model,
-            tokenizer=tokenizer,
-            device=-1  # CPU
-        )
-        _loaded_pipelines[model_name] = generator
-    return _loaded_pipelines[model_name]
-
-def call_local_inference(model_name, prompt, max_new_tokens=256):
+print(f"🔄 載入中文模型 {MODEL_NAME}")
+tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_DIR)
+model = AutoModelForCausalLM.from_pretrained(LOCAL_MODEL_DIR)
+
+# Fix the missing pad_token issue
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+
+generator = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    device=-1  # CPU
+)
+
+def call_local_inference(prompt, max_new_tokens=256):
     try:
-        generator = get_pipeline(model_name)
-
-        # ✅ Force Chinese mode: append a hint
-        if "中文" not in prompt and "Chinese" not in prompt:
+        # Force a Chinese hint onto the prompt
+        if "中文" not in prompt:
             prompt += "\n(請用中文回答)"
 
         outputs = generator(
@@ -83,55 +56,41 @@ def call_local_inference(model_name, prompt, max_new_tokens=256):
             max_new_tokens=max_new_tokens,
             do_sample=True,
             temperature=0.7,
-            pad_token_id=generator.tokenizer.pad_token_id
+            pad_token_id=tokenizer.pad_token_id
         )
         return outputs[0]["generated_text"]
     except Exception as e:
         return f"(生成失敗:{e})"
 
 # -------------------------------
-# 4. Auto mode
+# 4. Article generation
 # -------------------------------
-def pick_model_auto(segments):
-    if segments <= 3:
-        return "GPT2-Chinese"  # short piece → Chinese GPT2
-    elif segments <= 6:
-        return "BTLM-3B-8K"
-    else:
-        return "BART-Base"
-
-def generate_article_progress(query, model_name, segments=5):
+def generate_article_progress(query, segments=5):
     docx_file = "/tmp/generated_article.docx"
     doc = DocxDocument()
     doc.add_heading(query, level=1)
 
-    selected_model = pick_model_auto(segments) if model_name == "Auto" else model_name
-    print(f"👉 使用模型: {selected_model}")
-
     all_text = []
     base_prompt = f"請依據下列主題生成一篇中文文章,主題:{query}\n每段約150-200字。\n"
 
     for i in range(segments):
-        # ✅ Generate each paragraph independently
         prompt = base_prompt + f"第{i+1}段:"
-        paragraph = call_local_inference(selected_model, prompt)
+        paragraph = call_local_inference(prompt)
         all_text.append(paragraph)
         doc.add_paragraph(paragraph)
-
-        yield "\n\n".join(all_text), None, f"本次使用模型:{selected_model}"
+        yield "\n\n".join(all_text), None, f"本次使用模型:{MODEL_NAME}"
 
     doc.save(docx_file)
-    yield "\n\n".join(all_text), docx_file, f"本次使用模型:{selected_model}"
+    yield "\n\n".join(all_text), docx_file, f"本次使用模型:{MODEL_NAME}"
 
 # -------------------------------
 # 5. Gradio UI
 # -------------------------------
 with gr.Blocks() as demo:
     gr.Markdown("# 📺 電視弘法視頻生成文章RAG系統")
-    gr.Markdown("支援 GPT2-Chinese / BTLM-3B / BART-Base,Auto 模式會自動選擇,並強制中文輸出。")
+    gr.Markdown("固定使用 **GPT2-Chinese (uer/gpt2-chinese-cluecorpusmedium)** 生成中文文章。")
 
     query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
-    model_dropdown = gr.Dropdown(choices=list(MODEL_MAP.keys()), value="Auto", label="選擇生成模型")
     segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="段落數")
     output_text = gr.Textbox(label="生成文章")
     output_file = gr.File(label="下載 DOCX")
@@ -140,7 +99,7 @@ with gr.Blocks() as demo:
     btn = gr.Button("生成文章")
     btn.click(
         generate_article_progress,
-        inputs=[query_input, model_dropdown, segments_input],
+        inputs=[query_input, segments_input],
        outputs=[output_text, output_file, model_used_text]
     )
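For reference, the single-model inference path this commit converges on fits in a short, self-contained sketch. This is a minimal reading of the diff rather than the full app: the repo id, CPU device, and sampling settings come from app.py above, the demo prompt is illustrative, and the eos-as-pad fallback is kept only as a defensive measure (this tokenizer normally ships its own [PAD]).

# Minimal sketch of the new single-model path; assumes network access
# on the first run and a CPU-only host, as in the Space above.
import os
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

MODEL_NAME = "uer/gpt2-chinese-cluecorpusmedium"
LOCAL_MODEL_DIR = f"./models/{MODEL_NAME.split('/')[-1]}"

# Download once; later runs reuse the local snapshot.
if not os.path.exists(LOCAL_MODEL_DIR):
    snapshot_download(repo_id=MODEL_NAME, local_dir=LOCAL_MODEL_DIR)

tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_DIR)
model = AutoModelForCausalLM.from_pretrained(LOCAL_MODEL_DIR)
if tokenizer.pad_token is None:  # defensive fallback, mirrored from the diff
    tokenizer.pad_token = tokenizer.eos_token

generator = pipeline("text-generation", model=model,
                     tokenizer=tokenizer, device=-1)  # device=-1 → CPU

out = generator(
    "請依據下列主題生成一篇中文文章,主題:無常\n第1段:",  # illustrative prompt
    max_new_tokens=64,
    do_sample=True,
    temperature=0.7,
    pad_token_id=tokenizer.pad_token_id,
)
print(out[0]["generated_text"])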
 
 
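One behaviour worth noting: generate_article_progress is a generator, and Gradio runs generator event handlers as streams, pushing each yield to the bound outputs. That is why the article box fills in paragraph by paragraph while the DOCX download only appears with the final yield. A stripped-down sketch of that pattern in isolation (hypothetical names, placeholder text instead of model calls):

# Standalone illustration of Gradio's generator-handler streaming;
# every name below is illustrative, not taken from app.py.
import gradio as gr

def stream_paragraphs(topic, segments):
    parts = []
    for i in range(int(segments)):
        parts.append(f"第{i+1}段:{topic} ……")  # stand-in for a model call
        yield "\n\n".join(parts), None          # intermediate update: text only
    path = "/tmp/demo_article.txt"
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n\n".join(parts))             # attach the file only at the end
    yield "\n\n".join(parts), path              # final update: text + download

with gr.Blocks() as demo:
    topic = gr.Textbox(label="主題")
    segments = gr.Slider(1, 10, value=3, step=1, label="段落數")
    text_out = gr.Textbox(label="輸出")
    file_out = gr.File(label="下載")
    gr.Button("生成").click(stream_paragraphs, inputs=[topic, segments],
                            outputs=[text_out, file_out])

demo.launch()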