Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,7 +6,7 @@ from threading import Thread
|
|
| 6 |
# import time
|
| 7 |
import cv2
|
| 8 |
|
| 9 |
-
|
| 10 |
# import copy
|
| 11 |
import torch
|
| 12 |
|
|
@@ -34,8 +34,6 @@ from llava.mm_utils import (
|
|
| 34 |
|
| 35 |
from serve_constants import html_header
|
| 36 |
|
| 37 |
-
from PIL import Image
|
| 38 |
-
|
| 39 |
import requests
|
| 40 |
from PIL import Image
|
| 41 |
from io import BytesIO
|
|
@@ -46,6 +44,9 @@ import gradio_client
|
|
| 46 |
import subprocess
|
| 47 |
import sys
|
| 48 |
|
|
|
|
|
|
|
|
|
|
| 49 |
def install_gradio_4_35_0():
|
| 50 |
current_version = gr.__version__
|
| 51 |
if current_version != "4.35.0":
|
|
@@ -64,6 +65,11 @@ import gradio_client
|
|
| 64 |
print(f"Gradio version: {gr.__version__}")
|
| 65 |
print(f"Gradio-client version: {gradio_client.__version__}")
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
class InferenceDemo(object):
|
| 68 |
def __init__(
|
| 69 |
self, args, model_path, tokenizer, model, image_processor, context_len
|
|
@@ -113,6 +119,16 @@ def is_valid_video_filename(name):
|
|
| 113 |
else:
|
| 114 |
return False
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
def sample_frames(video_file, num_frames):
|
| 118 |
video = cv2.VideoCapture(video_file)
|
|
@@ -193,9 +209,14 @@ def bot(history):
|
|
| 193 |
if type(message[0]) is tuple:
|
| 194 |
images_this_term.append(message[0][0])
|
| 195 |
if is_valid_video_filename(message[0][0]):
|
|
|
|
|
|
|
| 196 |
num_new_images += our_chatbot.num_frames
|
| 197 |
-
|
|
|
|
| 198 |
num_new_images += 1
|
|
|
|
|
|
|
| 199 |
else:
|
| 200 |
num_new_images = 0
|
| 201 |
|
|
@@ -209,8 +230,11 @@ def bot(history):
|
|
| 209 |
for f in images_this_term:
|
| 210 |
if is_valid_video_filename(f):
|
| 211 |
image_list += sample_frames(f, our_chatbot.num_frames)
|
| 212 |
-
|
| 213 |
image_list.append(load_image(f))
|
|
|
|
|
|
|
|
|
|
| 214 |
image_tensor = [
|
| 215 |
our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][
|
| 216 |
0
|
|
@@ -219,6 +243,24 @@ def bot(history):
|
|
| 219 |
.to(our_chatbot.model.device)
|
| 220 |
for f in image_list
|
| 221 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
|
| 223 |
image_tensor = torch.stack(image_tensor)
|
| 224 |
image_token = DEFAULT_IMAGE_TOKEN * num_new_images
|
|
@@ -280,7 +322,19 @@ def bot(history):
|
|
| 280 |
our_chatbot.conversation.messages[-1][-1] = outputs
|
| 281 |
|
| 282 |
history[-1] = [text, outputs]
|
| 283 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
return history
|
| 285 |
# generate_kwargs = dict(
|
| 286 |
# inputs=input_ids,
|
|
@@ -345,7 +399,7 @@ with gr.Blocks(
|
|
| 345 |
|
| 346 |
with gr.Column():
|
| 347 |
with gr.Row():
|
| 348 |
-
chatbot = gr.Chatbot([], elem_id="
|
| 349 |
|
| 350 |
with gr.Row():
|
| 351 |
upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
|
|
@@ -560,8 +614,8 @@ if __name__ == "__main__":
|
|
| 560 |
argparser.add_argument("--model-base", type=str, default=None)
|
| 561 |
argparser.add_argument("--num-gpus", type=int, default=1)
|
| 562 |
argparser.add_argument("--conv-mode", type=str, default=None)
|
| 563 |
-
argparser.add_argument("--temperature", type=float, default=0.
|
| 564 |
-
argparser.add_argument("--max-new-tokens", type=int, default=
|
| 565 |
argparser.add_argument("--num_frames", type=int, default=16)
|
| 566 |
argparser.add_argument("--load-8bit", action="store_true")
|
| 567 |
argparser.add_argument("--load-4bit", action="store_true")
|
|
|
|
| 6 |
# import time
|
| 7 |
import cv2
|
| 8 |
|
| 9 |
+
import datetime
|
| 10 |
# import copy
|
| 11 |
import torch
|
| 12 |
|
|
|
|
| 34 |
|
| 35 |
from serve_constants import html_header
|
| 36 |
|
|
|
|
|
|
|
| 37 |
import requests
|
| 38 |
from PIL import Image
|
| 39 |
from io import BytesIO
|
|
|
|
| 44 |
import subprocess
|
| 45 |
import sys
|
| 46 |
|
| 47 |
+
external_log_dir = "./logs"
|
| 48 |
+
LOGDIR = external_log_dir
|
| 49 |
+
|
| 50 |
def install_gradio_4_35_0():
|
| 51 |
current_version = gr.__version__
|
| 52 |
if current_version != "4.35.0":
|
|
|
|
| 65 |
print(f"Gradio version: {gr.__version__}")
|
| 66 |
print(f"Gradio-client version: {gradio_client.__version__}")
|
| 67 |
|
| 68 |
+
def get_conv_log_filename():
|
| 69 |
+
t = datetime.datetime.now()
|
| 70 |
+
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
|
| 71 |
+
return name
|
| 72 |
+
|
| 73 |
class InferenceDemo(object):
|
| 74 |
def __init__(
|
| 75 |
self, args, model_path, tokenizer, model, image_processor, context_len
|
|
|
|
| 119 |
else:
|
| 120 |
return False
|
| 121 |
|
| 122 |
+
def is_valid_image_filename(name):
|
| 123 |
+
image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
|
| 124 |
+
|
| 125 |
+
ext = name.split(".")[-1].lower()
|
| 126 |
+
|
| 127 |
+
if ext in image_extensions:
|
| 128 |
+
return True
|
| 129 |
+
else:
|
| 130 |
+
return False
|
| 131 |
+
|
| 132 |
|
| 133 |
def sample_frames(video_file, num_frames):
|
| 134 |
video = cv2.VideoCapture(video_file)
|
|
|
|
| 209 |
if type(message[0]) is tuple:
|
| 210 |
images_this_term.append(message[0][0])
|
| 211 |
if is_valid_video_filename(message[0][0]):
|
| 212 |
+
# 不接受视频
|
| 213 |
+
raise ValueError("Video is not supported")
|
| 214 |
num_new_images += our_chatbot.num_frames
|
| 215 |
+
elif is_valid_image_filename(message[0][0]):
|
| 216 |
+
print("#### Load image from local file",message[0][0])
|
| 217 |
num_new_images += 1
|
| 218 |
+
else:
|
| 219 |
+
raise ValueError("Invalid image file")
|
| 220 |
else:
|
| 221 |
num_new_images = 0
|
| 222 |
|
|
|
|
| 230 |
for f in images_this_term:
|
| 231 |
if is_valid_video_filename(f):
|
| 232 |
image_list += sample_frames(f, our_chatbot.num_frames)
|
| 233 |
+
elif is_valid_image_filename(f):
|
| 234 |
image_list.append(load_image(f))
|
| 235 |
+
else:
|
| 236 |
+
raise ValueError("Invalid image file")
|
| 237 |
+
|
| 238 |
image_tensor = [
|
| 239 |
our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][
|
| 240 |
0
|
|
|
|
| 243 |
.to(our_chatbot.model.device)
|
| 244 |
for f in image_list
|
| 245 |
]
|
| 246 |
+
all_image_hash = []
|
| 247 |
+
for image_path in image_list:
|
| 248 |
+
with open(image_path, "rb") as image_file:
|
| 249 |
+
image_data = image_file.read()
|
| 250 |
+
image_hash = hashlib.md5(image_data).hexdigest()
|
| 251 |
+
all_image_hash.append(image_hash)
|
| 252 |
+
image = PIL.Image.open(image_path).convert("RGB")
|
| 253 |
+
all_images.append(image)
|
| 254 |
+
t = datetime.datetime.now()
|
| 255 |
+
filename = os.path.join(
|
| 256 |
+
LOGDIR,
|
| 257 |
+
"serve_images",
|
| 258 |
+
f"{t.year}-{t.month:02d}-{t.day:02d}",
|
| 259 |
+
f"{image_hash}.jpg",
|
| 260 |
+
)
|
| 261 |
+
if not os.path.isfile(filename):
|
| 262 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
| 263 |
+
image.save(filename)
|
| 264 |
|
| 265 |
image_tensor = torch.stack(image_tensor)
|
| 266 |
image_token = DEFAULT_IMAGE_TOKEN * num_new_images
|
|
|
|
| 322 |
our_chatbot.conversation.messages[-1][-1] = outputs
|
| 323 |
|
| 324 |
history[-1] = [text, outputs]
|
| 325 |
+
print("#### history",history)
|
| 326 |
+
|
| 327 |
+
with open(get_conv_log_filename(), "a") as fout:
|
| 328 |
+
data = {
|
| 329 |
+
"tstamp": round(finish_tstamp, 4),
|
| 330 |
+
"type": "chat",
|
| 331 |
+
"model": "Pangea-7b",
|
| 332 |
+
"start": round(start_tstamp, 4),
|
| 333 |
+
"finish": round(start_tstamp, 4),
|
| 334 |
+
"state": history,
|
| 335 |
+
"images": all_image_hash,
|
| 336 |
+
}
|
| 337 |
+
fout.write(json.dumps(data) + "\n")
|
| 338 |
return history
|
| 339 |
# generate_kwargs = dict(
|
| 340 |
# inputs=input_ids,
|
|
|
|
| 399 |
|
| 400 |
with gr.Column():
|
| 401 |
with gr.Row():
|
| 402 |
+
chatbot = gr.Chatbot([], elem_id="Pangea", bubble_full_width=False, height=750)
|
| 403 |
|
| 404 |
with gr.Row():
|
| 405 |
upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
|
|
|
|
| 614 |
argparser.add_argument("--model-base", type=str, default=None)
|
| 615 |
argparser.add_argument("--num-gpus", type=int, default=1)
|
| 616 |
argparser.add_argument("--conv-mode", type=str, default=None)
|
| 617 |
+
argparser.add_argument("--temperature", type=float, default=0.7)
|
| 618 |
+
argparser.add_argument("--max-new-tokens", type=int, default=4096)
|
| 619 |
argparser.add_argument("--num_frames", type=int, default=16)
|
| 620 |
argparser.add_argument("--load-8bit", action="store_true")
|
| 621 |
argparser.add_argument("--load-4bit", action="store_true")
|