import dataclasses
from enum import auto, Enum
from typing import List, Tuple
import os


class SeparatorStyle(Enum):
    """Different separator styles."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()
    MISTRAL = auto()


video_helper_map = {
    # 'Chips Making Deal Video': {'path': '/data/videos/ChipmakingDeal/sub-videos/', 'prefix': 'ChipmakingDeal_split'},
    'Innovation-2023': {'path': '/data1/tile_gh/Multimodal-RAG/videos/PatsKeynote23/sub-videos/', 'prefix': 'keynotes23_split'},
    'Behind-the-Bell-Intel': {'path': '/data1/tile_gh/Multimodal-RAG/videos/BehindTheBell/sub-videos/', 'prefix': 'Behind the Bell Intel_split'},
    'Foundry-Connect': {'path': '/data1/tile_gh/Multimodal-RAG/videos/SamPatTalkAI/sub-videos/', 'prefix': 'Sam Altman and Pat Gelsinger Talk Artificial Intelligence_split'},
    'Chips Act Funding Announcement': {'path': '/data1/tile_gh/Multimodal-RAG/videos/IntelChipsFundingAnnounce/sub-videos/', 'prefix': 'Intel Celebrates CHIPS and Science Act Direct Funding Announcement (Replay)_split'},
    '22nm-transistor-animation': {'path': '/data1/tile_gh/Multimodal-RAG/videos/MarkBohrExplains22nm/sub-videos/', 'prefix': 'Video Animation Mark Bohr Gets Small 22nm Explained Intel_split'},
    '14nm-transistor-animation': {'path': '/data1/tile_gh/Multimodal-RAG/videos/MarkBohrExplains14nm/sub-videos/', 'prefix': 'Explanation of Intels 14nm Process_split'},
}
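# Each entry maps a video title (as attached to a retrieved frame) to the
# directory of its pre-split sub-clips and the shared filename stem of those
# clips. get_path_to_subvideos() below uses this to turn a frame path such as
# '<path>/keynotes23_split_3.jpg' (hypothetical example) back into the
# matching .mp4 clip.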


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "\n"
    sep2: str = None
    version: str = "Unknown"
    path_to_img: str = None
    video_title: str = None
    caption: str = None
    skip_next: bool = False

    def _template_caption(self):
        out = ""
        if self.caption is not None:
            out = f"The caption associated with the image is '{self.caption}'. "
        return out

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 1 and messages[1][1] is not None and "<image>" not in messages[0][1]:
            # If there is already an assistant reply in the history and the
            # first user message does not yet contain <image>, prepend
            # "<image>\n" (plus any caption) to it.
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            messages[0] = (init_role, "<image>\n" + self._template_caption() + init_msg)
        if len(messages) > 1 and messages[1][1] is None:
            # RAG retrieval step: the prompt is the bare query only.
            ret = messages[0][1]
        else:
            if self.sep_style == SeparatorStyle.SINGLE:
                ret = ""
                for role, message in messages:
                    if message:
                        ret += role + ": " + message + self.sep
                    else:
                        ret += role + ":"
            elif self.sep_style == SeparatorStyle.LLAMA_2:
                wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
                wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
                ret = ""
                for i, (role, message) in enumerate(messages):
                    if i == 0:
                        assert message, "first message should not be none"
                        assert role == self.roles[0], "first message should come from user"
                    if message:
                        if type(message) is tuple:
                            message, _, _ = message
                        if i == 0:
                            message = wrap_sys(self.system) + message
                        if i % 2 == 0:
                            message = wrap_inst(message)
                            ret += self.sep + message
                        else:
                            ret += " " + message + " " + self.sep2
                ret = ret.lstrip(self.sep)
            else:
                raise ValueError(f"Invalid style: {self.sep_style}")
        return ret
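
    # Hand-checked sketch of what get_prompt() produces (assuming the default
    # templates defined at the bottom of this file):
    #
    #   SINGLE style (multimodal_rag), messages
    #     [["USER", "Hi"], ["ASSISTANT", "Hello"], ["USER", "Describe it"], ["ASSISTANT", None]]
    #   renders as
    #     "USER: <image>\nHi\nASSISTANT: Hello\nUSER: Describe it\nASSISTANT:"
    #
    #   LLAMA_2 style (conv_mistral_instruct), messages
    #     [["USER", "<image>\nHi"], ["ASSISTANT", "Hello"]]
    #   renders as
    #     "[INST] <image>\nHi [/INST] Hello </s>"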

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        # return_pil is kept for interface compatibility; this variant only
        # returns image paths.
        images = []
        if self.path_to_img is not None:
            path_to_image = self.path_to_img
            images.append(path_to_image)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    # Bound the displayed image to roughly 800x400 while
                    # keeping its aspect ratio.
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    # The image is saved as JPEG, so the data URI must declare
                    # image/jpeg.
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret
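
    # Hedged sketch of the shape to_gradio_chatbot() returns, i.e. what
    # gr.Chatbot consumes: a list of [user_message, assistant_message] pairs,
    # with any attached image inlined as a base64 <img> tag, e.g.
    #   [['<img src="data:image/jpeg;base64,..." alt="user upload image" />What is shown?', 'A wafer.']]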

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
            # Carry the multimodal state over so a copied conversation keeps
            # its attached image, video title, and caption.
            path_to_img=self.path_to_img,
            video_title=self.video_title,
            caption=self.caption,
        )

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "path_to_img": self.path_to_img,
            "video_title": self.video_title,
            "caption": self.caption,
        }

    def get_path_to_subvideos(self):
        print(f"self.video_title {self.video_title}")
        print(f"self.path_to_img {self.path_to_img}")
        if self.video_title is not None and self.path_to_img is not None:
            info = video_helper_map[self.video_title]
            path = info['path']
            prefix = info['prefix']
            # Recover the clip index from the frame filename, e.g.
            # 'keynotes23_split_3.jpg' -> '3'.
            vid_index = self.path_to_img.split('/')[-1]
            vid_index = vid_index.split('_')[-1]
            vid_index = vid_index.replace('.jpg', '')
            ret = f"{prefix}{vid_index}.mp4"
            ret = os.path.join(path, ret)
            return ret
        elif self.path_to_img is not None:
            return self.path_to_img
        return None
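
    # Hand-checked sketch (the frame filename is a hypothetical example): for
    # video_title 'Innovation-2023' and path_to_img '.../keynotes23_split_3.jpg',
    # this yields
    # '/data1/tile_gh/Multimodal-RAG/videos/PatsKeynote23/sub-videos/keynotes23_split3.mp4',
    # assuming the clips on disk are named '<prefix><index>.mp4'.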


multimodal_rag = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="\n",
    path_to_img=None,
    video_title=None,
    caption=None,
)

conv_mistral_instruct = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="",
    sep2="</s>",
    path_to_img=None,
    video_title=None,
    caption=None,
)

default_conversation = multimodal_rag

conv_templates = {
    "default": multimodal_rag,
    "multimodal_rag": multimodal_rag,
    "llavamed_rag": conv_mistral_instruct,
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())
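
    # Hedged usage sketch: build a pending user turn the way serving code
    # typically would. With only a user query and no assistant reply yet,
    # get_prompt() takes the RAG branch and returns the bare query string.
    conv = conv_templates["multimodal_rag"].copy()
    conv.append_message(conv.roles[0], "What does this chart show?")  # hypothetical query
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())  # -> "What does this chart show?"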