linoyts HF Staff commited on
Commit
e0ec356
·
verified ·
1 Parent(s): 592dbba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -55
app.py CHANGED
@@ -90,29 +90,91 @@ Please strictly follow the rewriting rules below:
90
  "Rewritten": "..."
91
  }
92
  '''
93
-
94
- def polish_prompt(prompt, img):
 
 
 
 
 
95
  prompt = f"{SYSTEM_PROMPT}\n\nUser Input: {prompt}\n\nRewritten Prompt:"
96
- success=False
97
- while not success:
98
- try:
99
- result = api(prompt, [img])
100
- # print(f"Result: {result}")
101
- # print(f"Polished Prompt: {polished_prompt}")
102
- if isinstance(result, str):
103
- result = result.replace('```json','')
104
- result = result.replace('```','')
105
- result = json.loads(result)
106
- else:
107
- result = json.loads(result)
108
-
109
- polished_prompt = result['Rewritten']
110
- polished_prompt = polished_prompt.strip()
111
- polished_prompt = polished_prompt.replace("\n", " ")
112
- success = True
113
- except Exception as e:
114
- print(f"[Warning] Error during API call: {e}")
115
- return polished_prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
 
118
  def encode_image(pil_image):
@@ -122,37 +184,35 @@ def encode_image(pil_image):
122
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
123
 
124
 
125
-
126
-
127
- def api(prompt, img_list, model="qwen-vl-max-latest", kwargs={}):
128
- import dashscope
129
- api_key = os.environ.get('DASH_API_KEY')
130
- if not api_key:
131
- raise EnvironmentError("DASH_API_KEY is not set")
132
- assert model in ["qwen-vl-max-latest"], f"Not implemented model {model}"
133
- sys_promot = "you are a helpful assistant, you should provide useful answers to users."
134
- messages = [
135
- {"role": "system", "content": sys_promot},
136
- {"role": "user", "content": []}]
137
- for img in img_list:
138
- messages[1]["content"].append(
139
- {"image": f"data:image/png;base64,{encode_image(img)}"})
140
- messages[1]["content"].append({"text": f"{prompt}"})
141
-
142
- response_format = kwargs.get('response_format', None)
143
-
144
- response = dashscope.MultiModalConversation.call(
145
- api_key=api_key,
146
- model=model, # For example, use qwen-plus here. You can change the model name as needed. Model list: https://help.aliyun.com/zh/model-studio/getting-started/models
147
- messages=messages,
148
- result_format='message',
149
- response_format=response_format,
150
- )
151
-
152
- if response.status_code == 200:
153
- return response.output.choices[0].message.content[0]['text']
154
- else:
155
- raise Exception(f'Failed to post: {response}')
156
 
157
  # --- Model Loading ---
158
  dtype = torch.bfloat16
@@ -172,7 +232,7 @@ def infer(
172
  seed=42,
173
  randomize_seed=False,
174
  true_guidance_scale=1.0,
175
- num_inference_steps=50,
176
  height=None,
177
  width=None,
178
  rewrite_prompt=True,
@@ -211,7 +271,7 @@ def infer(
211
  print(f"Negative Prompt: '{negative_prompt}'")
212
  print(f"Seed: {seed}, Steps: {num_inference_steps}, Guidance: {true_guidance_scale}, Size: {width}x{height}")
213
  if rewrite_prompt and len(pil_images) > 0:
214
- prompt = polish_prompt(prompt, pil_images[0])
215
  print(f"Rewritten Prompt: {prompt}")
216
 
217
 
 
90
  "Rewritten": "..."
91
  }
92
  '''
93
# --- Prompt Enhancement using Hugging Face InferenceClient ---
def _extract_rewritten(raw_response):
    """Pull the 'Rewritten' value out of a model response.

    The model is instructed to answer with a JSON object like
    {"Rewritten": "..."}, possibly wrapped in ```json fences, but it may
    also answer with plain text.  Falls back to the raw text whenever
    parsing fails, then normalizes whitespace.
    """
    text = raw_response
    if '{"Rewritten"' in text:
        try:
            # Strip markdown code fences before parsing.
            cleaned = text.replace('```json', '').replace('```', '')
            text = json.loads(cleaned).get('Rewritten', text)
        except (json.JSONDecodeError, AttributeError, TypeError):
            # Best effort: keep the unparsed response text.
            pass
    return text.strip().replace("\n", " ")


def polish_prompt_hf(prompt, img):
    """
    Rewrites the prompt using a Hugging Face InferenceClient.

    Returns the rewritten prompt, or the caller's original prompt when
    HF_TOKEN is unset or the API call fails.

    `img` is accepted for interface compatibility with the caller but is
    not sent: the model used here is text-only.
    # NOTE(review): confirm whether the image should be forwarded to a
    # vision-capable model instead.
    """
    # Keep the untouched user prompt so the fallback paths can return it.
    original_prompt = prompt

    # Ensure HF_TOKEN is set
    api_key = os.environ.get("HF_TOKEN")
    if not api_key:
        print("Warning: HF_TOKEN not set. Falling back to original prompt.")
        return original_prompt

    # Embed the rewriting instructions in the user message; the system
    # message stays generic.
    full_prompt = f"{SYSTEM_PROMPT}\n\nUser Input: {prompt}\n\nRewritten Prompt:"

    try:
        # Initialize the client
        client = InferenceClient(
            provider="cerebras",
            api_key=api_key,
        )

        # Format the messages for the chat completions API
        messages = [
            {"role": "system",
             "content": "you are a helpful assistant, you should provide useful answers to users."},
            {"role": "user", "content": full_prompt},
        ]

        # Call the API
        completion = client.chat.completions.create(
            model="Qwen/Qwen3-235B-A22B-Instruct-2507",
            messages=messages,
        )

        # Parse the response and extract the rewritten prompt.
        return _extract_rewritten(completion.choices[0].message.content)

    except Exception as e:
        print(f"Error during API call to Hugging Face: {e}")
        # Fallback to original prompt if enhancement fails
        return original_prompt
155
+
156
+ # def polish_prompt(prompt, img):
157
+ # prompt = f"{SYSTEM_PROMPT}\n\nUser Input: {prompt}\n\nRewritten Prompt:"
158
+ # success=False
159
+ # while not success:
160
+ # try:
161
+ # result = api(prompt, [img])
162
+ # # print(f"Result: {result}")
163
+ # # print(f"Polished Prompt: {polished_prompt}")
164
+ # if isinstance(result, str):
165
+ # result = result.replace('```json','')
166
+ # result = result.replace('```','')
167
+ # result = json.loads(result)
168
+ # else:
169
+ # result = json.loads(result)
170
+
171
+ # polished_prompt = result['Rewritten']
172
+ # polished_prompt = polished_prompt.strip()
173
+ # polished_prompt = polished_prompt.replace("\n", " ")
174
+ # success = True
175
+ # except Exception as e:
176
+ # print(f"[Warning] Error during API call: {e}")
177
+ # return polished_prompt
178
 
179
 
180
  def encode_image(pil_image):
 
184
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
185
 
186
 
187
+ # def api(prompt, img_list, model="qwen-vl-max-latest", kwargs={}):
188
+ # import dashscope
189
+ # api_key = os.environ.get('DASH_API_KEY')
190
+ # if not api_key:
191
+ # raise EnvironmentError("DASH_API_KEY is not set")
192
+ # assert model in ["qwen-vl-max-latest"], f"Not implemented model {model}"
193
+ # sys_promot = "you are a helpful assistant, you should provide useful answers to users."
194
+ # messages = [
195
+ # {"role": "system", "content": sys_promot},
196
+ # {"role": "user", "content": []}]
197
+ # for img in img_list:
198
+ # messages[1]["content"].append(
199
+ # {"image": f"data:image/png;base64,{encode_image(img)}"})
200
+ # messages[1]["content"].append({"text": f"{prompt}"})
201
+
202
+ # response_format = kwargs.get('response_format', None)
203
+
204
+ # response = dashscope.MultiModalConversation.call(
205
+ # api_key=api_key,
206
+ # model=model, # For example, use qwen-plus here. You can change the model name as needed. Model list: https://help.aliyun.com/zh/model-studio/getting-started/models
207
+ # messages=messages,
208
+ # result_format='message',
209
+ # response_format=response_format,
210
+ # )
211
+
212
+ # if response.status_code == 200:
213
+ # return response.output.choices[0].message.content[0]['text']
214
+ # else:
215
+ # raise Exception(f'Failed to post: {response}')
 
 
216
 
217
  # --- Model Loading ---
218
  dtype = torch.bfloat16
 
232
  seed=42,
233
  randomize_seed=False,
234
  true_guidance_scale=1.0,
235
+ num_inference_steps=8,
236
  height=None,
237
  width=None,
238
  rewrite_prompt=True,
 
271
  print(f"Negative Prompt: '{negative_prompt}'")
272
  print(f"Seed: {seed}, Steps: {num_inference_steps}, Guidance: {true_guidance_scale}, Size: {width}x{height}")
273
  if rewrite_prompt and len(pil_images) > 0:
274
+ prompt = polish_prompt_hf(prompt, pil_images)
275
  print(f"Rewritten Prompt: {prompt}")
276
 
277