Spaces:
Sleeping
Sleeping
| from time import time | |
| from io import BytesIO | |
| import torch | |
| import streamlit as st | |
| import streamlit.components.v1 as components | |
| import numpy as np | |
| import torch | |
| import logging | |
| from os import environ | |
| from transformers import OwlViTProcessor, OwlViTForObjectDetection | |
| from bot import Bot, Message | |
| from myscaledb import Client | |
| from classifier import Classifier, prompt2vec, tune, SplitLayer | |
| from query_model import simple_query, topk_obj_query, rev_query | |
| from card_model import card, obj_card, style | |
| from box_utils import postprocess | |
# Set the HuggingFace tokenizers parallelism flag for this process.
# NOTE(review): this runs after `transformers` has been imported above;
# presumably the flag is only read lazily by the tokenizer -- confirm.
environ["TOKENIZERS_PARALLELISM"] = "true"

# Names handed to the query helpers (topk_obj_query / rev_query / simple_query)
# and to Classifier; presumably `<database>.<table>` -- verify against the DB.
OBJ_DB_NAME = "mqdb_demo.coco_owl_vit_b_32_objects"
IMG_DB_NAME = "mqdb_demo.coco_owl_vit_b_32_images"
# HuggingFace identifier of the OwlViT checkpoint loaded by init_owlvit().
MODEL_ID = "google/owlvit-base-patch32"
# Length of a query vector: used for the random query and to validate the
# shape of uploaded ONNX weights.
DIMS = 512
# Duration of the latest query in milliseconds; stays 0 when no query was
# issued during the current script run.
qtime = 0
def build_model(name="google/owlvit-base-patch32"):
    """Load an OwlViT detector together with its processor.

    Args:
        name (str, optional): HuggingFace identifier of the OwlViT checkpoint.
            Defaults to "google/owlvit-base-patch32".

    Returns:
        (model, processor): the detection model, moved to CUDA when it is
            available and kept on CPU otherwise, and the processor loaded
            from the same checkpoint name.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    detector = OwlViTForObjectDetection.from_pretrained(name)
    detector = detector.to(target_device)
    return detector, OwlViTProcessor.from_pretrained(name)
def init_owlvit():
    """Build the OwlViT model configured by ``MODEL_ID``.

    Returns:
        model, processor: exactly what ``build_model(MODEL_ID)`` returns
    """
    return build_model(MODEL_ID)
def init_db():
    """Initialize the Database Connection

    Credentials are read from ``st.secrets`` (``DB_URL``, ``USER``, ``PASSWD``).

    Returns:
        meta_field: a new empty list; callers store it as the record of
            what the user has already viewed
        client: Database connection object

    Raises:
        ConnectionError: if the freshly created client reports that the
            connection is not alive.
    """
    client = Client(
        url=st.secrets["DB_URL"],
        user=st.secrets["USER"],
        password=st.secrets["PASSWD"],
    )
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, and a bare AssertionError carries no message for the
    # log / incident report written by the callers.
    if not client.is_alive():
        raise ConnectionError("Database connection is not alive")
    return [], client
def refresh_index():
    """Clean the session

    Resets the view history and query counter, reconnects to the database
    and drops the per-query state (``clf``, ``xq``, ``topk_img_id``) so the
    next script run starts from the prompt screen.
    """
    # Plain assignment replaces the former `del` + assignment pair, which
    # raised KeyError when "meta" was not in the session yet.
    st.session_state.meta = []
    st.session_state.query_num = 0
    logging.info(f"Refresh for '{st.session_state.meta}'")
    # Need to clear singleton function with streamlit API.
    # `clear` only exists when init_db is wrapped by a Streamlit cache
    # decorator; on a plain function the unconditional call raised
    # AttributeError, so look it up defensively.
    clear_cache = getattr(init_db, "clear", None)
    if callable(clear_cache):
        clear_cache()
    # refresh session states
    st.session_state.meta, st.session_state.index = init_db()
    for key in ("clf", "xq", "topk_img_id"):
        if key in st.session_state:
            del st.session_state[key]
def query(xq, exclude_list=None):
    """Query matched w.r.t a given vector

    In this part, we will retrieve A LOT OF data from the server,
    including TopK boxes and their embeddings, the counterpart of non-TopK
    boxes in TopK images.

    The whole sequence of four queries is attempted up to three times; after
    every failure the database connection is re-initialised. If every attempt
    fails, four empty lists are returned.

    Args:
        xq (numpy.ndarray or list of floats): Query vector. It is L2-normalised
            along the last axis before being sent.
        exclude_list (optional): forwarded to ``topk_obj_query`` as
            ``exclude_list`` (the user's view list). Defaults to None.

    Returns:
        matches: list of Records object. Keys referring to selected columns
            group by images. Exclude the user's viewlist.
        img_matches: list of Records object. Containing other non-TopK but hit
            objects among TopK images. Empty when ``matches`` has no image ids.
        side_matches: list of Records object. Containing REAL TopK objects
            disregard the user's view history
        o_matches: result of ``rev_query`` restricted to
            ``st.session_state.topk_img_id``, i.e. the image ids stored by the
            first query since that key was last cleared.
    """
    xq = xq / np.linalg.norm(xq, axis=-1, ord=2, keepdims=True)
    status_bar = [st.empty(), st.empty()]
    status_bar[0].write("Retrieving Another TopK Images...")
    pbar = status_bar[1].progress(0)

    # Bind every result up front so the final `return` is always valid, even
    # when all attempts fail before some of the queries were reached.
    matches, img_matches, side_matches, o_matches = [], [], [], []
    for _attempt in range(3):
        try:
            matches = topk_obj_query(
                st.session_state.index,
                xq,
                IMG_DB_NAME,
                OBJ_DB_NAME,
                exclude_list=exclude_list,
                topk=5000,
            )
            img_ids = [r["img_id"] for r in matches]
            # Remember the image ids of the very first query of this search.
            if "topk_img_id" not in st.session_state:
                st.session_state.topk_img_id = img_ids
            status_bar[0].write("Retrieving TopK Images...")
            pbar.progress(25)
            o_matches = rev_query(
                st.session_state.index,
                xq,
                st.session_state.topk_img_id,
                IMG_DB_NAME,
                OBJ_DB_NAME,
                thresh=0.1,
            )
            status_bar[0].write("Retrieving TopKs Objects...")
            pbar.progress(50)
            side_matches = simple_query(
                st.session_state.index,
                xq,
                IMG_DB_NAME,
                OBJ_DB_NAME,
                thresh=-1,
                topk=5000,
            )
            status_bar[0].write("Retrieving Non-TopK in Another TopK Images...")
            pbar.progress(75)
            if len(img_ids) > 0:
                img_matches = rev_query(
                    st.session_state.index,
                    xq,
                    img_ids,
                    IMG_DB_NAME,
                    OBJ_DB_NAME,
                    thresh=0.1,
                )
            else:
                img_matches = []
            status_bar[0].write("DONE!")
            pbar.progress(100)
            break
        except Exception as e:
            # force reload if we have trouble on connections or something else
            logging.warning(str(e))
            st.session_state.meta, st.session_state.index = init_db()
            # Discard partial results of the failed attempt so that stale
            # data from one query is never paired with another attempt.
            matches, img_matches, side_matches, o_matches = [], [], [], []

    for placeholder in status_bar:
        placeholder.empty()
    if len(matches) == 0:
        logging.error(f"No matches found for '{OBJ_DB_NAME}'")
    return matches, img_matches, side_matches, o_matches
def init_random_query():
    """Draw a random unit-length query vector.

    Returns:
        xq: array of shape (1, DIMS) sampled with ``np.random.rand`` and
            scaled to unit L2 norm along the last axis
    """
    sample = np.random.rand(1, DIMS)
    length = np.linalg.norm(sample, keepdims=True, axis=-1)
    return sample / length
def submit(meta):
    """Tune the model w.r.t given score from user.

    Extends the session's view history with ``meta``, advances the step
    counter, tunes the classifier on the labels currently selected in the
    ``label-<id>`` widgets and reruns the query with the tuned vector.

    Args:
        meta: items appended to ``st.session_state.meta``; that list is passed
            to ``query`` as the exclude list.
    """
    # Only updating the meta if the train button is pressed
    st.session_state.meta.extend(meta)
    st.session_state.step += 1
    matches = st.session_state.matched_boxes
    prompts = st.session_state.text_prompts

    # With no displayed boxes there is nothing to train on; the former
    # `X, y = list(zip(*...))` raised ValueError in that case. Skip tuning
    # and keep the current query vector.
    if matches:
        # X: first field of every stored box; y: index of the label the user
        # picked for that box within the prompt list.
        X = tuple(v[0] for v in matches.values())
        y = tuple(prompts.index(st.session_state[f"label-{i}"]) for i in matches)
        st.session_state.xq = tune(
            st.session_state.clf, X, y, iters=int(st.session_state.iters)
        )

    (
        st.session_state.matches,
        st.session_state.img_matches,
        st.session_state.side_matches,
        st.session_state.o_matches,
    ) = query(st.session_state.xq, st.session_state.meta)
# st.set_page_config(layout="wide")
# To hack the streamlit style we define our own style.
# Boxes are drawn in SVGs.
st.write(style(), unsafe_allow_html=True)
# Incident reporter; used by the handler at the bottom of the script.
bot = Bot(app_name="HF OwlViT", enabled=True, bot_key=st.secrets['BOT_KEY'])

# The whole UI runs inside one try-block: any exception is forwarded to the
# bot below and is not re-raised by this handler.
try:
    with st.spinner("Connecting DB..."):
        st.session_state.meta, st.session_state.index = init_db()

    with st.spinner("Loading Models..."):
        # Initialize model
        # (the second value is the OwlViT processor returned by build_model)
        model, tokenizer = init_owlvit()

    # If its a fresh start... (query not set)
    if "xq" not in st.session_state:
        with st.container():
            st.title("Object Detection Safari")
            # Placeholders for the start screen; slots 0-5 are filled below and
            # all of them are emptied once a query has been chosen.
            start = [st.empty() for _ in range(8)]
            start[0].info(
                """
                We extracted boxes from **287,104** images in COCO Dataset, including its train / val / test /
                unlabeled images, collecting **165,371,904 boxes** which are then filtered with common prompts.
                You can search with almost any words or phrases you can think of. Please enjoy your journey of
                an adventure to COCO.
                """
            )
            prompt = start[1].text_input(
                "Prompt:",
                value="",
                placeholder="Examples: football, billboard, stop sign, watermark ...",
            )
            with start[2].container():
                st.write(
                    "You can search with multiple keywords. Plese separate with commas but with no space."
                )
                st.write("For example: `cat,dog,tree`")
                st.markdown(
                    """
                    <p style="color:gray;"> Don\'t know what to search? Try <b>Random</b>!</p>
                    """,
                    unsafe_allow_html=True,
                )
            upld_model = start[4].file_uploader(
                "Or you can upload your previous run!", type="onnx"
            )
            upld_btn = start[5].button(
                "Use Loaded Weights", disabled=upld_model is None, on_click=refresh_index
            )
            with start[3]:
                col = st.columns(8)
                # "Random" is only enabled when neither a prompt nor an
                # uploaded model is present.
                has_no_prompt = len(prompt) == 0 and upld_model is None
                prompt_xq = col[6].button(
                    "Prompt", disabled=len(prompt) == 0, on_click=refresh_index
                )
                random_xq = col[7].button(
                    "Random", disabled=not has_no_prompt, on_click=refresh_index
                )
            matches = []
            img_matches = []
            if random_xq:
                # Random start: a random unit vector labelled "unknown".
                xq = init_random_query()
                st.session_state.xq = xq
                prompt = "unknown"
                # The trailing "none" entry is an extra label on top of the
                # prompts; it is sliced off ([:-1]) wherever prompts are
                # embedded or exported.
                st.session_state.text_prompts = prompt.split(",") + ["none"]
                _ = [elem.empty() for elem in start]
                t0 = time()
                (
                    st.session_state.matches,
                    st.session_state.img_matches,
                    st.session_state.side_matches,
                    st.session_state.o_matches,
                ) = query(st.session_state.xq, st.session_state.meta)
                t1 = time()
                # elapsed query time in milliseconds
                qtime = (t1 - t0) * 1000
            elif prompt_xq or upld_btn:
                if upld_model is not None:
                    # Restore a previous run from an uploaded ONNX file: the
                    # graph output names become the prompts and the first
                    # initializer, transposed, becomes the query matrix.
                    import onnx
                    from onnx import numpy_helper
                    _model = onnx.load(upld_model)
                    st.session_state.text_prompts = [
                        node.name for node in _model.graph.output
                    ] + ["none"]
                    weights = _model.graph.initializer
                    xq = numpy_helper.to_array(weights[0]).T
                    # One row per prompt, DIMS columns.
                    # NOTE(review): validation via assert is skipped under -O.
                    assert (
                        xq.shape[0] == len(st.session_state.text_prompts) - 1
                        and xq.shape[1] == DIMS
                    )
                    st.session_state.xq = xq
                    _ = [elem.empty() for elem in start]
                else:
                    # Text start: embed every comma-separated prompt.
                    logging.info(f"Input prompt is {prompt}")
                    st.session_state.text_prompts = prompt.split(",") + ["none"]
                    input_ids, xq = prompt2vec(
                        st.session_state.text_prompts[:-1], model, tokenizer
                    )
                    st.session_state.xq = xq
                    _ = [elem.empty() for elem in start]
                t0 = time()
                (
                    st.session_state.matches,
                    st.session_state.img_matches,
                    st.session_state.side_matches,
                    st.session_state.o_matches,
                ) = query(st.session_state.xq, st.session_state.meta)
                t1 = time()
                # elapsed query time in milliseconds
                qtime = (t1 - t0) * 1000

    # If its not a fresh start (query is set)
    if "xq" in st.session_state:
        o_matches = st.session_state.o_matches
        side_matches = st.session_state.side_matches
        img_matches = st.session_state.img_matches
        matches = st.session_state.matches
        # initialize classifier
        # (only once per search; refresh_index deletes "clf")
        if "clf" not in st.session_state:
            st.session_state.clf = Classifier(st.session_state.index, OBJ_DB_NAME, st.session_state.xq)
            st.session_state.step = 0
        # qtime is only non-zero when a query ran during this script run
        if qtime > 0:
            st.info(
                "Query done in {0:.2f} ms and returned {1:d} images with {2:d} boxes".format(
                    qtime,
                    len(matches),
                    sum(
                        [
                            len(m["box_id"]) + len(im["box_id"])
                            for m, im in zip(matches, img_matches)
                        ]
                    ),
                )
            )
        # Bias-free linear layer carrying the classifier weights: one output
        # per row of xq, one input per column of xq.
        lnprob = torch.nn.Linear(st.session_state.xq.shape[1], st.session_state.xq.shape[0], bias=False)
        lnprob.weight = torch.nn.Parameter(st.session_state.clf.weight)
        # export the model into executable ONNX
        # (written into an in-memory buffer; outputs are named after the prompts)
        st.session_state.dnld_model = BytesIO()
        torch.onnx.export(
            torch.nn.Sequential(lnprob, SplitLayer()),
            torch.zeros([1, len(st.session_state.xq[0])]),
            st.session_state.dnld_model,
            input_names=["input"],
            output_names=st.session_state.text_prompts[:-1],
        )
        # Default file name: prompts joined by "_" with spaces turned into "-".
        dnld_nam = st.text_input(
            "Download Name:",
            f'{("_".join([i.replace(" ", "-") for i in st.session_state.text_prompts[:-1]]) if "text_prompts" in st.session_state else "model")}.onnx',
            max_chars=50,
        )
        dnld_btn = st.download_button(
            "Download your classifier!", st.session_state.dnld_model, dnld_nam
        )
        # build up a sidebar to display REAL TopK in DB
        # this will change during user's finetune. But sometime it would lead to bad results
        # Edge length of a sidebar crop: shrinks with the number of labels,
        # capped at 120.
        side_bar_len = min(240 // len(st.session_state.text_prompts), 120)
        with st.sidebar:
            with st.expander("Top-K Images"):
                with st.container():
                    # Both ratios move towards 1 as `step` grows.
                    # NOTE(review): their exact meaning is defined in
                    # box_utils.postprocess, which is not visible here.
                    boxes_w_img, _ = postprocess(
                        o_matches, st.session_state.text_prompts, o_matches,
                        agnostic_ratio=1-0.6**(st.session_state.step+1),
                        class_ratio=1-0.2**(st.session_state.step+1)
                    )
                    # index 4 is the image score (see the unpacking below)
                    boxes_w_img = sorted(boxes_w_img, key=lambda x: x[4], reverse=True)
                    for img_id, img_url, img_w, img_h, img_score, boxes in boxes_w_img:
                        args = img_url, img_w, img_h, boxes
                        st.write(card(*args), unsafe_allow_html=True)
            with st.expander("Top-K Objects", expanded=True):
                # One column per prompt (the trailing "none" gets no column).
                side_cols = st.columns(len(st.session_state.text_prompts[:-1]))
                for _cols, m in zip(side_cols, side_matches):
                    with _cols.container():
                        for cx, cy, w, h, logit, img_url, img_w, img_h in zip(
                            m["cx"],
                            m["cy"],
                            m["w"],
                            m["h"],
                            m["logit"],
                            m["img_url"],
                            m["img_w"],
                            m["img_h"],
                        ):
                            st.write(
                                "{:s}: {:.4f}".format(
                                    st.session_state.text_prompts[m["label"]], logit
                                )
                            )
                            _html = obj_card(
                                img_url, img_w, img_h, cx, cy, w, h, dst_len=side_bar_len
                            )
                            components.html(_html, side_bar_len, side_bar_len)
        with st.container():
            # Here let the user interact with batch labeling
            with st.form("batch", clear_on_submit=False):
                col = st.columns([1, 9])
                # If there is nothing to show about
                if len(matches) <= 0:
                    st.warning(
                        "Oops! We didn't find anything relevant to your query! Pleas try another one :/"
                    )
                else:
                    # Read by submit() as the `iters` argument of tune().
                    st.session_state.iters = st.slider(
                        "Number of Iterations to Update",
                        min_value=0,
                        max_value=10,
                        step=1,
                        value=2,
                    )
                # No matter what happened the user wants a way back
                col[1].form_submit_button("Choose a new prompt", on_click=refresh_index)
                # If there are things to show
                if len(matches) > 0:
                    with st.container():
                        prompt_labels = st.session_state.text_prompts
                        # Post processing boxes regarding to their score, intersection
                        boxes_w_img, meta = postprocess(
                            matches, st.session_state.text_prompts, img_matches,
                            agnostic_ratio=1-0.6**(st.session_state.step+1),
                            class_ratio=1-0.2**(st.session_state.step+1)
                        )
                        # Sort the result according to their relevancy
                        boxes_w_img = sorted(boxes_w_img, key=lambda x: x[4], reverse=True)
                        # Rebuilt on every run; submit() reads it to collect
                        # the training samples.
                        st.session_state.matched_boxes = {}
                        # For each images in the retrieved images, DISPLAY
                        for img_id, img_url, img_w, img_h, img_score, boxes in boxes_w_img:
                            # prepare inputs for training
                            # (keyed by the first field of each box)
                            st.session_state.matched_boxes.update({b[0]: b for b in boxes})
                            args = img_url, img_w, img_h, boxes
                            # display boxes
                            with st.expander(
                                "{:s}: {:.4f}".format(img_id, img_score), expanded=True
                            ):
                                ind_b = 0
                                # 4 columns: (img, obj, obj, obj)
                                img_row = st.columns([4, 2, 2, 2])
                                img_row[0].write(card(*args), unsafe_allow_html=True)
                                # crop objects out of the original image
                                for b in boxes:
                                    _id, cx, cy, w, h, label, logit, is_selected = b[:8]
                                    # cycle through the three object columns
                                    with img_row[1 + ind_b % 3].container():
                                        st.write("{:s}: {:.4f}".format(label, logit))
                                        # quite hacky: with streamlit components API
                                        _html = obj_card(
                                            img_url, img_w, img_h, *b[1:5], dst_len=120
                                        )
                                        components.html(_html, 120, 120)
                                        # the user will choose the right label of the given object
                                        # (submit() reads the choice back via the "label-<id>" key)
                                        st.selectbox(
                                            "Class",
                                            prompt_labels,
                                            index=prompt_labels.index(label),
                                            key=f"label-{_id}",
                                        )
                                    ind_b += 1
                    # NOTE(review): placed inside the `len(matches) > 0` branch
                    # because `meta` is only bound there -- confirm placement.
                    col[0].form_submit_button("Train!", on_click=lambda: submit(meta))
except Exception as e:
    # Report the failure to the bot: only the message text and the exception
    # class name are sent (the traceback is dropped); nothing is re-raised.
    msg = Message()
    msg.content = str(e.with_traceback(None))
    msg.type_hint = str(type(e).__name__)
    bot.incident(msg)