import os
import argparse
import traceback
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
logging.getLogger("http").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)

import spaces
import gradio as gr
from conversation_public import default_conversation

auth_token = os.environ.get("TOKEN_FROM_SECRET")
##########################################
# LLM part
##########################################
import torch
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
from transformers import Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from qwen_vl_utils import process_vision_info
from threading import Thread
# === Prompts ===
SYSTEM_PROMPT_LLM = "You are a helpful assistant."
SYSTEM_PROMPT_CAP = "You are given an image and a relevant question. Based on the query, please describe the image in detail. Do not try to answer the question."
CAPTION_PROMPT = "Question: {}\nPlease describe the image. DO NOT try to answer the question!"
LLM_PROMPT = """In the following text, you will receive a detailed caption of an image and a relevant question. In addition, you will be provided with a tentative model response. Your goal is to answer the question using this information.\n\n### The detailed caption of the provided image: {}\n\n### Note that the caption might contain incorrect solutions, do not be misguided by them.\n\n### A problem to be solved: {}\n\n### A tentative model response: {}\n\n### Note that the above tentative response might be inaccurate (due to calculation errors, incorrect logic/reasoning and so on); in such a case, please ignore it and give your own solution. However, if you do not have enough evidence to show it is wrong, please output the tentative response."""
# === Initialize Models ===
MLLM_MODEL_PATH = "KaiChen1998/RACRO-7B-CRO-GRPO"
LLM_MODEL_PATH = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
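# RACRO decouples perception from reasoning: the Qwen2.5-VL-based MLLM extracts visual
# evidence (a query-conditioned caption plus a tentative answer), and the text-only
# DeepSeek-R1 distilled LLM performs the final reasoning over that text.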
processor = AutoProcessor.from_pretrained(MLLM_MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_PATH)
mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(MLLM_MODEL_PATH, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map="auto")
llm = AutoModelForCausalLM.from_pretrained(LLM_MODEL_PATH, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map="auto")
mllm_sampling = dict(do_sample=False, temperature=0, max_new_tokens=8192)
# do_sample=True is required for temperature/top_p to take effect in transformers' generate()
llm_sampling = dict(do_sample=True, temperature=0.6, top_p=0.95, max_new_tokens=8192)
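# Greedy decoding keeps the MLLM's caption and tentative answer deterministic, while the
# reasoner uses the sampling configuration recommended for the DeepSeek-R1-Distill models
# (temperature 0.6, top-p 0.95).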
# === Build Prompts ===
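# Two message lists are built per request: cap_msgs asks the MLLM to describe the image
# conditioned on the question without answering it, while qa_msgs asks it to answer the
# question directly and put the final answer in \boxed{}.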
def build_messages(image_path, question):
    cap_msgs = [
        {"role": "system", "content": SYSTEM_PROMPT_CAP},
        {"role": "user", "content": [{"type": "image", "image": image_path}, {"type": "text", "text": CAPTION_PROMPT.format(question)}]}
    ]
    qa_msgs = [
        {"role": "user", "content": [{"type": "image", "image": image_path}, {"type": "text", "text": question + " Please think step by step. The final answer MUST BE put in \\boxed{}."}]}
    ]
    return cap_msgs, qa_msgs
##########################################
# Streaming
##########################################
mllm_streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
llm_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
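# model.generate() blocks until completion, so it runs in a background thread while the
# TextIteratorStreamer yields decoded text chunks here; each partial string is re-yielded
# so the Gradio callbacks can refresh the chatbot incrementally.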
def stream_response(model, inputs, streamer, prompt, gen_kwargs):
    thread = Thread(target=model.generate, kwargs=dict(
        streamer=streamer,
        **inputs,
        **gen_kwargs
    ))
    thread.start()

    generated_text = prompt
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
##########################################
# Gradio part
##########################################
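# Placeholder Button values returned from event handlers: enable_btn / disable_btn update the
# interactivity of the Regenerate and Clear buttons, and no_change_btn is returned when they
# should be left untouched.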
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
server_oom_msg = "**OUT OF GPU MEMORY DETECTED. PLEASE DECREASE THE MAX OUTPUT TOKENS AND REGENERATE.**"
def load_demo_refresh_model_list():
    logging.info("load_demo.")
    state = default_conversation.copy()
    return state

def regenerate(state, image_process_mode):
    logging.info("regenerate.")
    state.messages = state.messages[:-3]
    prev_human_msg = state.messages[-1]
    if type(prev_human_msg[1]) in (tuple, list):
        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode, *prev_human_msg[1][3:])
    state.skip_next = False
    return (state, state.to_gradio_chatbot_public(), "", None) + (disable_btn,) * 2

def clear_history():
    logging.info("clear_history.")
    state = default_conversation.copy()
    return (state, state.to_gradio_chatbot_public(), "", None) + (disable_btn,) * 2
############
# Show prompt in the chatbot
# Input: [state, textbox, imagebox, image_process_mode]
# Return: [state, chatbot, textbox, imagebox] + btn_list
############
def add_text(state, text, image, image_process_mode):
    # Input legality checking
    logging.info(f"add_text. len: {len(text)}")
    if len(text) <= 0 or image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot_public(), "", None) + (no_change_btn,) * 2

    # Deal with image inputs
    if image is not None:
        text = (text, image, image_process_mode, None)

    # Single round only
    state = default_conversation.copy()
    state.append_message(state.roles[0], text)
    state.skip_next = False
    logging.info(str(state.messages))
    return (state, state.to_gradio_chatbot_public(), "", None) + (disable_btn,) * 2
############
# Get response
# Input: [state]
# Return: [state, chatbot] + btn_list
############
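# Three-stage RACRO inference per request, each stage streamed to the chatbot as its own
# assistant message:
#   1. the MLLM answers the question directly (tentative response)
#   2. the MLLM writes a query-conditioned caption of the image
#   3. the text-only LLM reasons over {caption, question, tentative response}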
def http_bot(state):
    logging.info("http_bot.")
    if state.skip_next:
        yield (state, state.to_gradio_chatbot_public()) + (no_change_btn,) * 2
        return

    # Retrieve prompt
    prompt = state.messages[-1][-1][0]
    all_images = state.get_images(return_pil=True)[0]
    pload = {"prompt": prompt, "images": f'List of {len(state.get_images())} images: {all_images}'}
    logging.info(f"==== request ====\n{pload}")
    # Construct prompt
    cap_msgs, qa_msgs = build_messages(all_images, prompt)
    cap_prompt = processor.apply_chat_template(cap_msgs, tokenize=False, add_generation_prompt=True)
    qa_prompt = processor.apply_chat_template(qa_msgs, tokenize=False, add_generation_prompt=True)
    image_tensor, _ = process_vision_info(cap_msgs)
    cap_inputs = processor(text=cap_prompt, images=image_tensor, return_tensors="pt").to(mllm.device)
    qa_inputs = processor(text=qa_prompt, images=image_tensor, return_tensors="pt").to(mllm.device)
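    # Both prompts share the same visual input: process_vision_info() pulls the image out of
    # the chat messages so the processor can pack text and pixels into model-ready tensors.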
    # Step 1: Tentative Response
    state.append_message(state.roles[1], "# Tentative Response\n\n▌")
    try:
        for generated_text in stream_response(mllm, qa_inputs, mllm_streamer, qa_prompt, mllm_sampling):
            output = generated_text[len(qa_prompt):].strip()
            state.messages[-1][-1] = "# Tentative Response\n\n" + output + "▌"
            yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
    except Exception as e:
        os.system("nvidia-smi")
        logging.error(traceback.format_exc())  # format_exc() returns the traceback string (print_exc() returns None)
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
        return

    tentative_answer = output
    logging.info(f"Tentative Response: {tentative_answer}")
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
    # Step 2: Query-conditioned Caption
    state.append_message(state.roles[1], "# Query-conditioned Caption\n\n▌")
    try:
        for generated_text in stream_response(mllm, cap_inputs, mllm_streamer, cap_prompt, mllm_sampling):
            output = generated_text[len(cap_prompt):].strip()
            state.messages[-1][-1] = "# Query-conditioned Caption\n\n" + output + "▌"
            yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
    except Exception as e:
        os.system("nvidia-smi")
        logging.error(traceback.format_exc())
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
        return

    caption_text = output
    logging.info(f"Query-conditioned Caption: {caption_text}")
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
    # Step 3: Text-only Reasoning
    reason_msgs = [
        {"role": "system", "content": SYSTEM_PROMPT_LLM},
        {"role": "user", "content": LLM_PROMPT.format(caption_text, prompt, tentative_answer)}
    ]
    reason_prompt = tokenizer.apply_chat_template(reason_msgs, tokenize=False, add_generation_prompt=True)
    reason_inputs = tokenizer(reason_prompt, return_tensors="pt").to(llm.device)
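    # The reasoner never sees the image: it receives only the caption, the original question,
    # and the tentative answer, rendered through the LLM's own chat template.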
    state.append_message(state.roles[1], "# Text-only Reasoning\n\n▌")
    try:
        for generated_text in stream_response(llm, reason_inputs, llm_streamer, reason_prompt, llm_sampling):
            output = generated_text[len(reason_prompt):].strip()
            state.messages[-1][-1] = "# Text-only Reasoning\n\n" + output + "▌"
            yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
    except Exception as e:
        os.system("nvidia-smi")
        logging.error(traceback.format_exc())
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
        return

    final_response = output
    logging.info(f"Text-only Reasoning: {final_response}")
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
############
# Layout Markdown
############
title_markdown = ("""
<div style="display: flex; align-items: center; padding: 20px; border-radius: 10px; background-color: #f0f0f0;">
  <div>
    <h1 style="margin: 0;">RACRO: Perceptual Decoupling for Scalable Multi-modal Reasoning via Reward-Optimized Captioning</h1>
    <h2 style="margin: 10px 0;">📃 <a href="https://www.arxiv.org/abs/2506.04559" style="font-weight: 400;">Paper</a> | 💻 <a href="https://github.com/gyhdog99/RACRO2" style="font-weight: 400;">Code</a> | 🤗 <a href="https://huggingface.co/collections/KaiChen1998/racro-6848ec8c65b3a0bf33d0fbdb" style="font-weight: 400;">HuggingFace</a></h2>
    <p style="margin: 20px 0;">
      <strong>1. RACRO is designed for multi-modal reasoning, and thus, image inputs are <mark>ALWAYS</mark> necessary!</strong>
    </p>
  </div>
</div>
""")
learn_more_markdown = ("""
## Citation
<pre><code>@article{gou2025perceptual,
  author  = {Gou, Yunhao and Chen, Kai and Liu, Zhili and Hong, Lanqing and Jin, Xin and Li, Zhenguo and Kwok, James T. and Zhang, Yu},
  title   = {Perceptual Decoupling for Scalable Multi-modal Reasoning via Reward-Optimized Captioning},
  journal = {arXiv preprint arXiv:2506.04559},
  year    = {2025},
}</code></pre>
""")
block_css = """
#buttons button {
    min-width: min(120px, 100%);
}
.message-row img {
    margin: 0px !important;
}
.avatar-container img {
    padding: 0px !important;
}
"""
############
# Layout Demo
############
def build_demo(embed_mode):
    textbox = gr.Textbox(label="Text", show_label=False, placeholder="Enter text and then click 💬 Chat to talk with me ^v^", container=False)
    with gr.Blocks(title="RACRO", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()
        if not embed_mode:
            gr.HTML(title_markdown)

        ##############
        # Chatbot
        ##############
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                imagebox = gr.Image(type="pil", label="Image")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)
                gr.Examples(examples=[
                    ["./examples/image-text/demo_example.jpg", "When the canister is momentarily stopped by the spring, by what distance $d$ is the spring compressed?"],
                ], inputs=[imagebox, textbox], label='Examples')

            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    type="messages",
                    elem_id="chatbot",
                    label="RACRO Chatbot",
                    layout="bubble",
                    avatar_images=["examples/user_avator.png", "examples/icon_256.png"]
                )
                textbox.render()
                with gr.Row(elem_id="buttons") as button_row:
                    submit_btn = gr.Button(value="💬 Chat", variant="primary")
                    # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
        if not embed_mode:
            gr.Markdown(learn_more_markdown)

        # Register listeners
        btn_list = [regenerate_btn, clear_btn]
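        # Every trigger (Regenerate, pressing Enter in the textbox, or the Chat button) first
        # updates the conversation state, then chains into http_bot, which streams the
        # three-stage response back to the chatbot.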
        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            http_bot,
            [state],
            [state, chatbot] + btn_list,
        )
        clear_btn.click(
            clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )
        # Triggered when the user presses Enter in the textbox
        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state],
            [state, chatbot] + btn_list,
        )
        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            http_bot,
            [state],
            [state, chatbot] + btn_list,
        )
        ##############
        # Demo loading
        ##############
        demo.load(
            load_demo_refresh_model_list,
            None,
            [state],
            queue=False
        )
    return demo
parser = argparse.ArgumentParser()
parser.add_argument("--share", action="store_true")
parser.add_argument("--embed", action="store_true")
args = parser.parse_args()

demo = build_demo(args.embed)
demo.queue(
    max_size=10,
    api_open=False
).launch(
    favicon_path="./examples/icon_256.png",
    allowed_paths=["/"],
    share=args.share
)