Spaces:
Running
Running
| import argparse | |
| import sys | |
| import os | |
| # import cv2 | |
| import glob | |
| import gradio as gr | |
| import numpy as np | |
| import json | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from pathlib import Path | |
| import uvicorn | |
| from fastapi.staticfiles import StaticFiles | |
| import random | |
| import time | |
| import requests | |
| from fastapi import FastAPI | |
| from conversation import SeparatorStyle, conv_templates, default_conversation | |
| from utils import ( | |
| build_logger, | |
| moderation_msg, | |
| server_error_msg, | |
| ) | |
| from config import cur_conv | |
# Module-wide logger, built by the project's own build_logger helper.
logger = build_logger("gradio_web_server", "gradio_web_server.log")

# Headers attached to the POST request issued by http_bot.
headers = {"Content-Type": "application/json"}

# FastAPI application; static files and the Gradio UI are both mounted on it.
app = FastAPI()

# Directory exposed under the /static route.
# Previously: /data/Multimodal-RAG/GenerativeAIExamples/ChatQnA/langchain/redis/chips-making-deals/
static_dir = Path('/data/')
app.mount(
    "/static",
    StaticFiles(directory=static_dir),
    name="static",
)
# Custom Gradio theme: two blue palettes plus dark-mode and border overrides.
theme = gr.themes.Base(
    primary_hue=gr.themes.Color(
        c50="#eff6ff",
        c100="#dbeafe",
        c200="#bfdbfe",
        c300="#93c5fd",
        c400="#60a5fa",
        c500="#0054ae",
        c600="#00377c",
        c700="#00377c",
        c800="#1e40af",
        c900="#1e3a8a",
        c950="#0a0c2b",
    ),
    secondary_hue=gr.themes.Color(
        c50="#eff6ff",
        c100="#dbeafe",
        c200="#bfdbfe",
        c300="#93c5fd",
        c400="#60a5fa",
        c500="#0054ae",
        c600="#0054ae",
        c700="#0054ae",
        c800="#1e40af",
        c900="#1e3a8a",
        c950="#1d3660",
    ),
).set(
    body_background_fill_dark='*primary_950',
    body_text_color_dark='*neutral_300',
    border_color_accent='*primary_700',
    border_color_accent_dark='*neutral_800',
    block_background_fill_dark='*primary_950',
    block_border_width='2px',
    block_border_width_dark='2px',
    button_primary_background_fill_dark='*primary_500',
    button_primary_border_color_dark='*primary_500',
)
# Extra CSS handed to gr.Blocks: declares the IntelOne font face that the
# page header in html_title refers to.
css = '''
@font-face {
font-family: IntelOne;
src: url("file/assets/intelone-bodytext-font-family-regular.ttf");
}
'''
# Page header rendered through gr.HTML: a one-row table holding the logos and
# the two-line title.
# Disabled alternative logo cell:
## <td style="border-bottom:0"><img src="file/assets/DCAI_logo.png" height="300" width="300"></td>
html_title = '''
<table>
<tr style="height:150px">
<td style="border-bottom:0"><img src="file/assets/intel-labs.png" height="100" width="100"></td>
<td style="border-bottom:0; vertical-align:bottom">
<p style="font-size:xx-large;font-family:IntelOne, Georgia, sans-serif;color: white;">
Cognitive AI:
<br>
Multimodal RAG on Videos
</p>
</td>
<td style="border-bottom:0;"><img src="file/assets/gaudi.png" width="100" height="100"></td>
<td style="border-bottom:0;"><img src="file/assets/xeon.png" width="100" height="100"></td>
<td style="border-bottom:0;"><img src="file/assets/IDC7.png" width="400" height="350"></td>
</tr>
</table>
'''
# Flip to True to have print_debug echo its argument on stdout.
debug = False


def print_debug(t):
    """Print ``t``, but only while the module-level ``debug`` flag is truthy."""
    if not debug:
        return
    print(t)
| # https://stackoverflow.com/a/57781047 | |
| # Resizes a image and maintains aspect ratio | |
| # def maintain_aspect_ratio_resize(image, width=None, height=None, inter=cv2.INTER_AREA): | |
| # # Grab the image size and initialize dimensions | |
| # dim = None | |
| # (h, w) = image.shape[:2] | |
| # # Return original image if no need to resize | |
| # if width is None and height is None: | |
| # return image | |
| # # We are resizing height if width is none | |
| # if width is None: | |
| # # Calculate the ratio of the height and construct the dimensions | |
| # r = height / float(h) | |
| # dim = (int(w * r), height) | |
| # # We are resizing width if height is none | |
| # else: | |
| # # Calculate the ratio of the width and construct the dimensions | |
| # r = width / float(w) | |
| # dim = (width, int(h * r)) | |
| # # Return the resized image | |
| # return cv2.resize(image, dim, interpolation=inter) | |
def time_to_frame(time, fps):
    """Convert a time in seconds into a zero-based frame number.

    Args:
        time: Position in the video, in seconds.
        fps: Frame rate of the video, in frames per second.

    Returns:
        ``int(time * fps - 1)``, clamped so it is never negative. Without the
        clamp a time of 0 produced -1, which silently addresses the *last*
        frame when used as a Python index.
    """
    frame = int(time * fps - 1)
    return max(frame, 0)
def str2time(strtime):
    """Convert a timestamp string into a number of seconds.

    Accepted forms are ``"HH:MM:SS"``, ``"MM:SS"`` and ``"SS"`` (each field
    may be fractional). Surrounding whitespace and double quotes are ignored.

    Args:
        strtime: The timestamp text, optionally wrapped in double quotes.

    Returns:
        The total number of seconds as a float.

    Raises:
        ValueError: If a field is not numeric or there are more than three
            colon-separated fields.
    """
    # Strip whitespace first so quotes are still removed from ' "00:01:02" '.
    cleaned = strtime.strip().strip('"')
    fields = [float(c) for c in cleaned.split(':')]
    if len(fields) > 3:
        raise ValueError(f"too many fields in timestamp: {strtime!r}")
    # Left-pad missing leading fields (hours, then minutes) with zero.
    hrs, mins, seconds = [0.0] * (3 - len(fields)) + fields
    total_seconds = hrs * 60**2 + mins * 60 + seconds
    return total_seconds
def get_iframe(video_path: str, start: int = -1, end: int = -1):
    """Return an HTML ``<video>`` element that plays ``video_path``.

    Args:
        video_path: URL or path placed in the element's ``src`` attribute.
            It is HTML-escaped so quotes or angle brackets in the path cannot
            break out of the attribute.
        start: Accepted for interface compatibility; currently not used.
        end: Accepted for interface compatibility; currently not used.

    Returns:
        The markup for a 540x310 video player with controls.
    """
    from html import escape

    safe_src = escape(video_path, quote=True)
    return f"""<video controls="controls" preload="metadata" src="{safe_src}" width="540" height="310"></video>"""
| #TODO | |
| # def place(galleries, evt: gr.SelectData): | |
| # print(evt.value) | |
| # start_time = evt.value.split('||')[0].strip() | |
| # print(start_time) | |
| # # sub_video_id = evt.value.split('|')[-1] | |
| # if start_time in start_time_index_map.keys(): | |
| # sub_video_id = start_time_index_map[start_time] | |
| # else: | |
| # sub_video_id = 0 | |
| # path_to_sub_video = f"/static/video_embeddings/mp4.keynotes23/sub-videos/keynotes23_split{sub_video_id}.mp4" | |
| # # return evt.value | |
| # return get_iframe(path_to_sub_video) | |
| # def process(text_query): | |
| # tmp_dir = os.environ.get('VID_CACHE_DIR', os.environ.get('TMPDIR', './video_embeddings')) | |
| # frames, transcripts = run_query(text_query, path=tmp_dir) | |
| # # return video_file_path, [(image, caption) for image, caption in zip(frame_paths, transcripts)] | |
| # return [(frame, caption) for frame, caption in zip(frames, transcripts)], "" | |
# One-line blurb for the page (its only use, gr.Markdown(description), is
# commented out in the UI definition).
description = "This Space lets you engage with multimodal RAG on a video through a chat box."

# Reusable button updates returned by the event handlers:
# leave the button alone, make it clickable, or grey it out.
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)
| # textbox = gr.Textbox( | |
| # show_label=False, placeholder="Enter text and press ENTER", container=False | |
| # ) | |
def clear_history(request: gr.Request):
    """Reset the conversation to a fresh copy of ``cur_conv``.

    Returns a 5-tuple: the new state, its chatbot rendering, an empty string
    for the query box, ``None`` for the video, and ``disable_btn`` for the
    clear button.
    """
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh_state = cur_conv.copy()
    chat_view = fresh_state.to_gradio_chatbot()
    return (fresh_state, chat_view, "", None, disable_btn)
def add_text(state, text, request: gr.Request):
    """Append the user's query to the conversation state.

    The listener that calls this handler registers four outputs
    (state, chatbot, query box, clear button), so both branches return
    exactly four values. The empty-input branch used to return five, which
    does not match that output list.

    Args:
        state: Conversation object; must provide ``roles``,
            ``append_message`` and ``to_gradio_chatbot``.
        text: The query typed or picked in the dropdown. ``None`` is treated
            as an empty query.
        request: The Gradio request, used only for logging the client host.

    Returns:
        ``(state, chatbot_messages, "", button_update)``.
    """
    # The dropdown may hand over None when nothing is selected; normalise it
    # so len() and slicing below are safe.
    text = text or ""
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if len(text) <= 0:
        # Tell http_bot to skip the round-trip for this empty submission.
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", no_change_btn)

    text = text[:1536]  # Hard cut-off
    state.append_message(state.roles[0], text)
    # Placeholder for the reply that http_bot fills in.
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", disable_btn)
def http_bot(
    state, request: gr.Request
):
    """Send the pending query to the RAG worker and stream the reply.

    This is a generator. Every yielded value is a 4-tuple
    ``(state, chatbot_messages, path_to_sub_videos, button_update)``.

    Flow:
      * If ``state.skip_next`` is set, yield the current view once and stop.
      * On the first round, restart from a fresh copy of ``cur_conv`` that
        holds only the new user message.
      * With no images in the state, POST ``{"query": prompt}`` to
        ``/v1/rag/chat``; otherwise POST the prompt plus the first image
        path to ``/v1/rag/multi_turn_chat``.
      * The response is read as NUL-delimited JSON chunks; each ``answer``
        fragment is appended to the last message, followed by a cursor.

    Any request failure, non-2xx status or undecodable chunk replaces the
    last message with ``server_error_msg`` and re-enables the button, so the
    UI is never left showing the cursor with the button disabled.

    Args:
        state: Conversation object for this session.
        request: The Gradio request, used only for logging the client host.
    """
    logger.info(f"http_bot. ip: {request.client.host}")

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        path_to_sub_videos = state.get_path_to_subvideos()
        yield (state, state.to_gradio_chatbot(), path_to_sub_videos, no_change_btn)
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation
        new_state = cur_conv.copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Construct prompt
    prompt = state.get_prompt()
    all_images = state.get_images(return_pil=False)

    # Pick endpoint and payload
    if len(all_images) == 0:
        # First query: the worker has to run retrieval.
        pload = {
            "query": prompt,
        }
        url = worker_addr + "/v1/rag/chat"
    else:
        # Subsequent queries: reuse the image that is already in the state.
        pload = {
            "prompt": prompt,
            "path-to-image": all_images[0],
        }
        url = worker_addr + "/v1/rag/multi_turn_chat"
    logger.info(f"==== request ====\n{pload}")
    logger.info(f"==== url request ====\n{url}")

    # Show a cursor while waiting for the first fragment.
    state.messages[-1][-1] = "▌"
    path_to_sub_videos = state.get_path_to_subvideos()
    yield (state, state.to_gradio_chatbot(), path_to_sub_videos, disable_btn)

    try:
        # Stream output; the context manager releases the connection even if
        # the consumer stops iterating this generator early.
        with requests.post(
            url, headers=headers, json=pload, timeout=100, stream=True
        ) as response:
            # Turn HTTP error statuses into RequestException (HTTPError).
            response.raise_for_status()
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if not chunk:
                    continue
                res = json.loads(chunk.decode())
                if state.path_to_img is None and 'path-to-image' in res:
                    state.path_to_img = res['path-to-image']
                if state.video_title is None and 'title' in res:
                    state.video_title = res['title']
                if 'answer' in res:
                    output = res["answer"]
                    # Drop the trailing cursor, append the fragment, re-add it.
                    state.messages[-1][-1] = state.messages[-1][-1][:-1] + output + "▌"
                    path_to_sub_videos = state.get_path_to_subvideos()
                    yield (state, state.to_gradio_chatbot(), path_to_sub_videos, disable_btn)
                    time.sleep(0.03)
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.error(f"http_bot request failed: {e}")
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot(), None, enable_btn)
        return

    # Remove the trailing cursor from the finished reply.
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    path_to_sub_videos = state.get_path_to_subvideos()
    logger.info(path_to_sub_videos)
    yield (state, state.to_gradio_chatbot(), path_to_sub_videos, enable_btn)

    logger.info(f"{state.messages[-1][-1]}")
    return
# Sample queries offered as choices in the "Query" dropdown.
dropdown_list = [
    # Intel announcements
    "What did Intel present at Nasdaq?",
    "From Chips Act Funding Announcement, by which year is Intel committed to Net Zero gas emissions?",
    "What percentage of renewable energy is Intel planning to use?",
    # Free-form description
    "a band playing music",
    # Geography
    "Which US state is Silicon Desert referred to?",
    "and which US state is Silicon Forest referred to?",
    # Transistors
    "How do trigate fins work?",
    "What is the advantage of trigate over planar transistors?",
    "What are key objectives of transistor design?",
    "How fast can transistors switch?",
]
| # Build the Gradio UI: page header, video player, chatbot, query dropdown, | |
| # "Send" button and "Clear history" button, then wire the click handlers. | |
| # NOTE(review): indentation was lost in this copy of the file, so the exact | |
| # nesting of the Row/Column containers below cannot be confirmed from here. | |
| with gr.Blocks(theme=theme, css=css) as demo: | |
| # gr.Markdown(description) | |
| # Per-session conversation state, seeded from a copy of default_conversation. | |
| state = gr.State(default_conversation.copy()) | |
| gr.HTML(value=html_title) | |
| with gr.Row(): | |
| with gr.Column(scale=4): | |
| video = gr.Video(height=512, width=512, elem_id="video" ) | |
| with gr.Column(scale=7): | |
| chatbot = gr.Chatbot( | |
| elem_id="chatbot", label="Multimodal RAG Chatbot", height=450 | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=8): | |
| # textbox.render() | |
| # Dropdown used as the query box: sample queries plus free-form input. | |
| textbox = gr.Dropdown( | |
| dropdown_list, | |
| allow_custom_value=True, | |
| # show_label=False, | |
| # container=False, | |
| label="Query", | |
| info="Enter your query here or choose a sample from the dropdown list!" | |
| ) | |
| with gr.Column(scale=1, min_width=50): | |
| submit_btn = gr.Button( | |
| value="Send", variant="primary", interactive=True | |
| ) | |
| with gr.Row(elem_id="buttons") as button_row: | |
| clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) | |
| # Register listeners | |
| # btn_list is appended to every output list so handlers can update the clear button. | |
| btn_list = [clear_btn] | |
| # Five outputs: state, chatbot, textbox, video, clear button. | |
| clear_btn.click( | |
| clear_history, None, [state, chatbot, textbox, video] + btn_list | |
| ) | |
| # textbox.submit( | |
| # add_text, | |
| # [state, textbox], | |
| # [state, chatbot, textbox,] + btn_list, | |
| # ).then( | |
| # http_bot, | |
| # [state, ], | |
| # [state, chatbot, video] + btn_list, | |
| # ) | |
| # Two-step chain: add_text records the query (four outputs), then http_bot | |
| # produces the reply into state, chatbot, video and the clear button. | |
| submit_btn.click( | |
| add_text, | |
| [state, textbox], | |
| [state, chatbot, textbox,] + btn_list, | |
| ).then( | |
| http_bot, | |
| [state, ], | |
| [state, chatbot, video] + btn_list, | |
| ) | |
| print_debug('Beginning') | |
| # btn.click(fn=process, | |
| # inputs=[text_query], | |
| # # outputs=[video_player, gallery], | |
| # outputs=[gallery, html], | |
| # ) | |
| # gallery.select(place, [gallery], [html]) | |
# Enable Gradio's request queue for the UI.
demo.queue()

# Attach the Gradio UI to the FastAPI app at the root path; `app` is rebound
# to the object returned by mount_gradio_app.
app = gr.mount_gradio_app(app, demo, path='/')

# Leftover launch settings; nothing in the visible part of this file reads them.
share, enable_queue = False, True
| # try: | |
| # demo.queue(concurrency_count=3)#, enable_queue=False) | |
| # demo.launch(enable_queue=enable_queue, share=share, server_port=17808, server_name='0.0.0.0') | |
| # #BATCH -w isl-gpu48 | |
| # except: | |
| # demo.launch(enable_queue=False, share=share, server_port=17808, server_name='0.0.0.0') | |
| # serve the app | |
if __name__ == "__main__":
    # Command-line options, declared as (flag, add_argument keyword args).
    cli_options = (
        ("--host", {"type": str, "default": "0.0.0.0"}),
        ("--port", {"type": int, "default": 7899}),
        ("--concurrency-count", {"type": int, "default": 20}),
        ("--share", {"action": "store_true"}),
        ("--worker-address", {"type": str, "default": "198.175.88.247"}),
        ("--worker-port", {"type": int, "default": 7899}),
    )
    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # Module-level name that http_bot reads when it builds the request URL.
    # Assigned at module scope here, so no `global` statement is required.
    worker_addr = f"http://{args.worker_address}:{args.worker_port}"

    # Serve the FastAPI app (with the Gradio UI mounted on it).
    uvicorn.run(app, host=args.host, port=args.port)
| # for i in examples: | |
| # print(f'Processing {i[0]}') | |
| # results = process(*i) | |
| # print(f'{len(results[0])} results returned') | |