ttengwang committed
Commit eabdb1c · Parent(s): 12dc496

improve chat box; add a enable_wiki button

Files changed:
- app.py +64 -38
- caption_anything.py +2 -2
- text_refiner/text_refiner.py +8 -6
app.py
CHANGED
@@ -120,28 +120,44 @@ def update_click_state(click_state, caption, click_mode):
         raise NotImplementedError


-def chat_with_points(chat_input, click_state, state, text_refiner):
+def chat_with_points(chat_input, click_state, chat_state, state, text_refiner):
     if text_refiner is None:
         response = "Text refiner is not initilzed, please input openai api key."
         state = state + [(chat_input, response)]
-        return state, state
+        return state, state, chat_state

     points, labels, captions = click_state
-    # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting!
+    # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting!"
+    suffix = '\nHuman: {chat_input}\nAI: '
+    qa_template = '\nHuman: {q}\nAI: {a}'
     # # "The image is of width {width} and height {height}."
-    point_chat_prompt = "
+    point_chat_prompt = "I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps} \n Now, let's chat!"
     prev_visual_context = ""
-    pos_points = [
+    pos_points = []
+    pos_captions = []
+    for i in range(len(points)):
+        if labels[i] == 1:
+            pos_points.append(f"({points[i][0]}, {points[i][0]})")
+            pos_captions.append(captions[i])
+            prev_visual_context = prev_visual_context + '\n' + 'Points: ' +', '.join(pos_points) + '. Description: ' + pos_captions[-1]
+
+    context_length_thres = 500
+    prev_history = ""
+    for i in range(len(chat_state)):
+        q, a = chat_state[i]
+        if len(prev_history) < context_length_thres:
+            prev_history = prev_history + qa_template.format(**{"q": q, "a": a})
+        else:
+            break
+
+    chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context}) + prev_history + suffix.format(**{"chat_input": chat_input})
+    print('\nchat_prompt: ', chat_prompt)
     response = text_refiner.llm(chat_prompt)
     state = state + [(chat_input, response)]
+    chat_state = chat_state + [(chat_input, response)]
+    return state, state, chat_state

-def inference_seg_cap(image_input, point_prompt, click_mode, language, sentiment, factuality,
+def inference_seg_cap(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
                       length, image_embedding, state, click_state, original_size, input_size, text_refiner, evt:gr.SelectData):

     model = build_caption_anything_with_models(

@@ -173,11 +189,12 @@ def inference_seg_cap(image_input, point_prompt, click_mode, language, sentiment
     prompt = get_prompt(coordinate, click_state, click_mode)
     print('prompt: ', prompt, 'controls: ', controls)

+    enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
+    out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
+    state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
     # for k, v in out['generated_captions'].items():
     # state = state + [(f'{k}: {v}', None)]
-    state = state + [("
+    state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
     wiki = out['generated_captions'].get('wiki', "")

     update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)

@@ -191,15 +208,18 @@ def inference_seg_cap(image_input, point_prompt, click_mode, language, sentiment

     yield state, state, click_state, chat_input, image_input, wiki
     if not args.disable_gpt and model.text_refiner:
-        refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
+        refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'], enable_wiki=enable_wiki)
         # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
         new_cap = refined_caption['caption']
+        wiki = refined_caption['wiki']
+        state = state + [(None, f"caption: {new_cap}")]
         refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
         yield state, state, click_state, chat_input, refined_image_input, wiki


 def upload_callback(image_input, state):
-    state = [] + [('Image size: ' + str(image_input.size)
+    state = [] + [(None, 'Image size: ' + str(image_input.size))]
+    chat_state = []
     click_state = [[], [], []]
     res = 1024
     width, height = image_input.size

@@ -219,7 +239,7 @@ def upload_callback(image_input, state):
     image_embedding = model.segmenter.image_embedding
     original_size = model.segmenter.predictor.original_size
     input_size = model.segmenter.predictor.input_size
-    return state, state, image_input, click_state, image_input, image_embedding, original_size, input_size
+    return state, state, chat_state, image_input, click_state, image_input, image_embedding, original_size, input_size

 with gr.Blocks(
     css='''

@@ -229,6 +249,7 @@ with gr.Blocks(
 ) as iface:
     state = gr.State([])
     click_state = gr.State([[],[],[]])
+    chat_state = gr.State([])
     origin_image = gr.State(None)
     image_embedding = gr.State(None)
     text_refiner = gr.State(None)

@@ -260,14 +281,13 @@ with gr.Blocks(
             clear_button_image = gr.Button(value="Clear Image", interactive=True)
             with gr.Column(visible=False) as modules_need_gpt:
                 with gr.Row(scale=1.0):
-                    )
+                    language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
+                    sentiment = gr.Radio(
+                        choices=["Positive", "Natural", "Negative"],
+                        value="Natural",
+                        label="Sentiment",
+                        interactive=True,
+                    )
                 with gr.Row(scale=1.0):
                     factuality = gr.Radio(
                         choices=["Factual", "Imagination"],

@@ -281,8 +301,13 @@ with gr.Blocks(
                        value=10,
                        step=1,
                        interactive=True,
-                       label="Length",
-                    )
+                       label="Generated Caption Length",
+                    )
+                    enable_wiki = gr.Radio(
+                        choices=["Yes", "No"],
+                        value="No",
+                        label="Enable Wiki",
+                        interactive=True)
             with gr.Column(visible=True) as modules_not_need_gpt3:
                 gr.Examples(
                     examples=examples,

@@ -303,7 +328,7 @@ with gr.Blocks(
             with gr.Column(visible=False) as modules_not_need_gpt2:
                 chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=550,scale=0.5)
                 with gr.Column(visible=False) as modules_need_gpt3:
-                    chat_input = gr.Textbox(
+                    chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(container=False)
                     with gr.Row():
                         clear_button_text = gr.Button(value="Clear Text", interactive=True)
                         submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")

@@ -320,30 +345,30 @@ with gr.Blocks(
        show_progress=False
    )
    clear_button_image.click(
-       lambda: (None, [], [], [[], [], []], "", ""),
+       lambda: (None, [], [], [], [[], [], []], "", ""),
        [],
-       [image_input, chatbot, state, click_state, wiki_output, origin_image],
+       [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image],
        queue=False,
        show_progress=False
    )
    clear_button_text.click(
-       lambda: ([], [], [[], [], []]),
+       lambda: ([], [], [[], [], [], []], []),
        [],
-       [chatbot, state, click_state],
+       [chatbot, state, click_state, chat_state],
        queue=False,
        show_progress=False
    )
    image_input.clear(
-       lambda: (None, [], [], [[], [], []], "", ""),
+       lambda: (None, [], [], [], [[], [], []], "", ""),
        [],
-       [image_input, chatbot, state, click_state, wiki_output, origin_image],
+       [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image],
        queue=False,
        show_progress=False
    )

-   image_input.upload(upload_callback,[image_input, state], [chatbot, state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
-   chat_input.submit(chat_with_points, [chat_input, click_state, state, text_refiner], [chatbot, state])
-   example_image.change(upload_callback,[example_image, state], [
+   image_input.upload(upload_callback,[image_input, state], [chatbot, state, chat_state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
+   chat_input.submit(chat_with_points, [chat_input, click_state, chat_state, state, text_refiner], [chatbot, state, chat_state])
+   example_image.change(upload_callback,[example_image, state], [chatbot, state, chat_state, origin_image, click_state, image_input, image_embedding, original_size, input_size])

    # select coordinate
    image_input.select(inference_seg_cap,

@@ -351,6 +376,7 @@ with gr.Blocks(
                        origin_image,
                        point_prompt,
                        click_mode,
+                       enable_wiki,
                        language,
                        sentiment,
                        factuality,
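For reference, the core of the reworked chat box is the prompt assembly in chat_with_points: the captions of positively labeled click points are folded into a preamble, previous question/answer turns from the new chat_state are appended until a rough character budget (context_length_thres = 500) is exceeded, and the current user message is attached as a Human/AI suffix. Below is a minimal, framework-free sketch of that assembly; the sample point data and the build_chat_prompt helper are hypothetical, and in app.py the inputs come from click_state while the prompt is sent to text_refiner.llm.

# Minimal sketch (assumption: standalone re-statement of the prompt assembly
# performed by the new chat_with_points; not the app's actual function).
POINT_CHAT_PROMPT = (
    "I am an AI trained to chat with you about an image based on specific points (w, h) "
    "you provide, along with their visual descriptions. Here are the points and their "
    "descriptions you've given me: {points_with_caps} \n Now, let's chat!"
)
QA_TEMPLATE = '\nHuman: {q}\nAI: {a}'
SUFFIX = '\nHuman: {chat_input}\nAI: '
CONTEXT_LENGTH_THRES = 500  # rough character budget for past turns

def build_chat_prompt(points, labels, captions, chat_state, chat_input):
    # Collect descriptions of positively labeled click points.
    visual_context = ""
    pos_points = []
    for (x, y), label, caption in zip(points, labels, captions):
        if label == 1:
            pos_points.append(f"({x}, {y})")
            visual_context += '\nPoints: ' + ', '.join(pos_points) + '. Description: ' + caption

    # Append past Q/A turns until the character budget is reached.
    history = ""
    for q, a in chat_state:
        if len(history) >= CONTEXT_LENGTH_THRES:
            break
        history += QA_TEMPLATE.format(q=q, a=a)

    return (POINT_CHAT_PROMPT.format(points_with_caps=visual_context)
            + history
            + SUFFIX.format(chat_input=chat_input))

# Hypothetical usage:
prompt = build_chat_prompt([(120, 80)], [1], ["a brown dog"],
                           [("What is this?", "A dog.")], "What breed is it?")
print(prompt)

After the LLM replies, the committed code appends the (question, response) pair to both state, which feeds the Chatbot widget, and the new chat_state, which feeds the history portion of the next prompt.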
caption_anything.py
CHANGED
@@ -30,7 +30,7 @@ class CaptionAnything():
             self.text_refiner = None
             print('OpenAI GPT is not available')

-    def inference(self, image, prompt, controls, disable_gpt=False):
+    def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
         # segment with prompt
         print("CA prompt: ", prompt, "CA controls",controls)
         seg_mask = self.segmenter.inference(image, prompt)[0, ...]

@@ -59,7 +59,7 @@ class CaptionAnything():
         if self.args.context_captions:
             context_captions.append(self.captioner.inference(image))
         if not disable_gpt and self.text_refiner is not None:
-            refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions)
+            refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions, enable_wiki=enable_wiki)
         else:
             refined_caption = {'raw_caption': caption}
         out = {'generated_captions': refined_caption,
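To see how the new flag threads through these two methods, here is a small self-contained sketch with stub classes standing in for CaptionAnything and TextRefiner; the stubs, placeholder caption, and sample values are hypothetical, while the keyword names (disable_gpt, enable_wiki) and the "Yes"/"No" coercion mirror the diff.

# Sketch (assumption: StubRefiner/StubCaptionAnything are hypothetical stand-ins,
# not the real classes; they only illustrate how enable_wiki is forwarded).
class StubRefiner:
    def inference(self, query, controls, context=[], enable_wiki=False):
        # Mirrors the keyword added to TextRefiner.inference in this commit.
        return {"raw_caption": query, "caption": query,
                "wiki": "stub wiki text" if enable_wiki else ""}

class StubCaptionAnything:
    def __init__(self):
        self.text_refiner = StubRefiner()

    def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
        caption = "a brown dog"  # placeholder for the captioner output
        if not disable_gpt and self.text_refiner is not None:
            refined = self.text_refiner.inference(query=caption, controls=controls,
                                                  context=[], enable_wiki=enable_wiki)
        else:
            refined = {"raw_caption": caption}
        return {"generated_captions": refined}

# app.py coerces the "Enable Wiki" radio value before calling inference:
radio_value = "Yes"  # hypothetical UI selection
enable_wiki = True if radio_value in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
out = StubCaptionAnything().inference(image=None, prompt={}, controls={}, enable_wiki=enable_wiki)
print(out["generated_captions"])  # the 'wiki' field is only populated when enable_wiki is True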
text_refiner/text_refiner.py
CHANGED
@@ -39,7 +39,7 @@ class TextRefiner:
         print('prompt: ', input)
         return input

-    def inference(self, query: str, controls: dict, context: list=[]):
+    def inference(self, query: str, controls: dict, context: list=[], enable_wiki=False):
         """
         query: the caption of the region of interest, generated by captioner
         controls: a dict of control singals, e.g., {"length": 5, "sentiment": "positive"}

@@ -58,15 +58,17 @@ class TextRefiner:
         response = self.llm(input)
         response = self.parse(response)

+        response_wiki = ""
+        if enable_wiki:
+            tmp_configs = {"query": query}
+            prompt_wiki = self.wiki_prompts.format(**tmp_configs)
+            response_wiki = self.llm(prompt_wiki)
+            response_wiki = self.parse2(response_wiki)
         out = {
             'raw_caption': query,
             'caption': response,
             'wiki': response_wiki
-            }
+        }
         print(out)
         return out
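The wiki lookup is now gated: only when enable_wiki is set does the refiner substitute the raw caption into its wiki prompt template and make a second LLM call, whose parsed reply becomes the 'wiki' field of the output dict. A standalone sketch of that branch follows; the prompt text, stub_llm, and parse_wiki helpers are hypothetical stand-ins for self.wiki_prompts, self.llm, and self.parse2.

# Sketch (assumption: hypothetical stand-ins for the TextRefiner internals).
WIKI_PROMPT = "Give a short encyclopedia-style background for: {query}"  # hypothetical template

def stub_llm(prompt: str) -> str:
    # Placeholder for the OpenAI call made by self.llm.
    return f"[LLM reply to: {prompt!r}]"

def parse_wiki(raw: str) -> str:
    # Placeholder for self.parse2.
    return raw.strip()

def refine(query: str, response: str, enable_wiki: bool = False) -> dict:
    response_wiki = ""
    if enable_wiki:
        # The extra LLM round-trip is only paid when the user enabled the Wiki radio.
        prompt_wiki = WIKI_PROMPT.format(query=query)
        response_wiki = parse_wiki(stub_llm(prompt_wiki))
    return {"raw_caption": query, "caption": response, "wiki": response_wiki}

print(refine("a brown dog", "A playful brown dog.", enable_wiki=True))
print(refine("a brown dog", "A playful brown dog.", enable_wiki=False))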