CHUNYU0505 committed on
Commit
fc6e44c
·
verified ·
1 Parent(s): 3bcec19

改成小模型

Browse files

大模型跑不動

Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -61,15 +61,17 @@ def call_hf_inference(model_name, prompt, max_new_tokens=512):
61
  "parameters": {"max_new_tokens": max_new_tokens}
62
  }
63
  try:
64
- response = requests.post(api_url, headers=HEADERS, json=payload, timeout=60)
65
  response.raise_for_status()
66
  data = response.json()
67
  if isinstance(data, list) and "generated_text" in data[0]:
68
  return data[0]["generated_text"]
69
  elif isinstance(data, dict) and "error" in data:
70
- return f"(生成失敗:{data['error']}"
71
  else:
72
  return str(data)
 
 
73
  except Exception as e:
74
  return f"(生成失敗:{e})"
75
 
@@ -120,11 +122,11 @@ with gr.Blocks() as demo:
120
  model_dropdown = gr.Dropdown(
121
  choices=[
122
  "gpt2",
123
- "EleutherAI/gpt-neo-2.7B",
124
- "EleutherAI/gpt-j-6B",
125
- "facebook/bart-large-cnn"
126
  ],
127
- value="EleutherAI/gpt-neo-2.7B",
128
  label="選擇生成模型"
129
  )
130
  segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="段落數")
 
61
  "parameters": {"max_new_tokens": max_new_tokens}
62
  }
63
  try:
64
+ response = requests.post(api_url, headers=HEADERS, json=payload, timeout=180) # timeout 拉長
65
  response.raise_for_status()
66
  data = response.json()
67
  if isinstance(data, list) and "generated_text" in data[0]:
68
  return data[0]["generated_text"]
69
  elif isinstance(data, dict) and "error" in data:
70
+ return f"(生成失敗:{data['error']},請嘗試換一個模型)"
71
  else:
72
  return str(data)
73
+ except requests.exceptions.ReadTimeout:
74
+ return "(生成失敗:等待超時,請嘗試換小一點的模型或增加 timeout 秒數)"
75
  except Exception as e:
76
  return f"(生成失敗:{e})"
77
 
 
122
  model_dropdown = gr.Dropdown(
123
  choices=[
124
  "gpt2",
125
+ "facebook/bart-large-cnn",
126
+ "bigscience/bloom-560m",
127
+ "bigscience/bloomz-560m"
128
  ],
129
+ value="bigscience/bloomz-560m", # 預設比較聽得懂指令
130
  label="選擇生成模型"
131
  )
132
  segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="段落數")