import gradio as gr
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces
import cv2
import numpy as np
from PIL import Image
from gradio_client import Client, handle_file
import os
# Initialize TTS client with HF token
hf_token = os.getenv("HF_TOKEN")  # Set your HF token as an environment variable
tts_client = Client("dofbi/galsenai-xtts-v2-wolof-inference", hf_token=hf_token)
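# The remote Space ("dofbi/galsenai-xtts-v2-wolof-inference") exposes an XTTS-v2 Wolof
# synthesis endpoint at /predict; each call sends the text plus a reference speaker clip.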
def tts(text):
"""Generate TTS using Gradio API client"""
try:
result = tts_client.predict(
text=text,
audio_reference=handle_file('https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav'),
api_name="/predict"
)
print(f"TTS result: {result}") # Debug print to see what's returned
# Handle different possible return formats
if isinstance(result, tuple):
# If result is a tuple, the audio file might be in the first element
return result[0] if result else None
elif isinstance(result, str):
# If result is a string (file path)
return result
elif hasattr(result, 'name'):
# If result is a file object with a name attribute
return result.name
else:
# Try to return the result as-is
return result
except Exception as e:
print(f"TTS API Error: {e}")
return None
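# The reference speaker clip is re-wrapped with handle_file() on every call. If TTS latency
# matters, it could be hoisted to a module-level constant (hypothetical name), e.g.:
#   REFERENCE_AUDIO = handle_file('https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav')
# and passed as audio_reference=REFERENCE_AUDIO inside tts().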
def progress_bar_html(label: str) -> str:
"""
Returns an HTML snippet for a thin progress bar with a label.
The progress bar is styled as a dark animated bar.
"""
return f'''
<div style="display: flex; align-items: center;">
<span style="margin-right: 10px; font-size: 14px;">{label}</span>
<div style="width: 110px; height: 5px; background-color: #9370DB; border-radius: 2px; overflow: hidden;">
<div style="width: 100%; height: 100%; background-color: #4B0082; animation: loading 1.5s linear infinite;"></div>
</div>
</div>
<style>
@keyframes loading {{
0% {{ transform: translateX(-100%); }}
100% {{ transform: translateX(100%); }}
}}
</style>
'''
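# NOTE: progress_bar_html() is not referenced elsewhere in this file; it could be yielded
# at the start of model_inference (e.g. yield progress_bar_html("Processing...")) to show
# an animated status line while generation warms up.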
def downsample_video(video_path):
"""
Downsamples the video to 10 evenly spaced frames.
Each frame is converted to a PIL Image along with its timestamp.
"""
vidcap = cv2.VideoCapture(video_path)
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = vidcap.get(cv2.CAP_PROP_FPS)
frames = []
if total_frames <= 0 or fps <= 0:
vidcap.release()
return frames
# Sample 10 evenly spaced frames.
frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
for i in frame_indices:
vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
success, image = vidcap.read()
if success:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(image)
timestamp = round(i / fps, 2)
frames.append((pil_image, timestamp))
vidcap.release()
return frames
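# min_pixels/max_pixels bound how many 28x28 visual patches the Qwen2.5-VL processor
# produces per image, which caps the number of image tokens fed to the model.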
MODEL_ID = "yaya-sy/chvtr" # "kaamd/chtvctr" # Alternatively: "Qwen/Qwen2.5-VL-3B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, min_pixels=256*28*28, max_pixels=1280*28*28)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_ID,
trust_remote_code=True,
torch_dtype=torch.bfloat16
).to("cuda").eval()
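# @spaces.GPU requests a GPU for the duration of each call when running on a ZeroGPU Space.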
@spaces.GPU
def model_inference(input_dict, history):
text = input_dict["text"]
files = input_dict["files"]
if text.strip().lower().startswith("@video-infer"):
# Remove the tag from the query.
text = text[len("@video-infer"):].strip()
if not files:
raise gr.Error("Please upload a video file along with your @video-infer query.") # Fixed: gr.Error syntax
# Assume the first file is a video.
video_path = files[0]
frames = downsample_video(video_path)
if not frames:
raise gr.Error("Could not process video.") # Fixed: gr.Error syntax
# Build messages: start with the text prompt.
messages = [
# {"role": "system", "content": "Answer clearly to the user's requesst. Please do not use numbers, only letters. If you want to answer with a number, convert it to letters. For example, you should not say 'am an 2 xaj' but 'am an Γ±aari xaj.'"},
{
"role": "user",
"content": [{"type": "text", "text": text}]
}
]
# Append each frame with a timestamp label.
for image, timestamp in frames:
messages[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
messages[0]["content"].append({"type": "image", "image": image})
# Collect only the images from the frames.
video_images = [image for image, _ in frames]
# Prepare the prompt.
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(
text=[prompt],
images=video_images,
return_tensors="pt",
padding=True,
).to("cuda")
# Set up streaming generation.
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        # do_sample=True is required for temperature/min_p to actually affect decoding.
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=128, do_sample=True, temperature=2.0, min_p=0.8)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
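        # generate() runs in the background thread; the streamer yields decoded text
        # incrementally so partial answers can be shown while generation continues.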
buffer = ""
for new_text in streamer:
buffer += new_text
time.sleep(0.001)
yield buffer
        return  # End of the @video-infer branch; the image/text path below is not reached.
if len(files) > 1:
images = [load_image(image) for image in files]
elif len(files) == 1:
images = [load_image(files[0])]
else:
images = []
if text == "" and not images:
raise gr.Error("Please input a query and optionally image(s).") # Fixed: gr.Error syntax
if text == "" and images:
raise gr.Error("Please input a text query along with the image(s).") # Fixed: gr.Error syntax
messages = [
# {"role": "system", "content": "Answer clearly to the user's requesst. Please do not use numbers, only letters. If you want to answer with a number, convert it to letters. For example, you should not say 'am an 2 xaj' but 'am an Γ±aari xaj.'"},
{
"role": "user",
"content": [
*[{"type": "image", "image": image} for image in images],
{"type": "text", "text": text},
],
}
]
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(
text=[prompt],
images=images if images else None,
return_tensors="pt",
padding=True,
).to("cuda")
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    # do_sample=True is required for temperature/min_p to actually affect decoding.
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=128, do_sample=True, temperature=2.0, min_p=0.8)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
buffer = ""
for new_text in streamer:
buffer += new_text
time.sleep(0.01)
yield buffer
    # Streaming is complete; synthesize speech for the final text. Note that a generator's
    # return value is not surfaced by Gradio, so this audio path is only meaningful to a
    # caller that consumes it directly; the respond() wrapper below calls tts() itself.
    audio_path = tts(buffer)
    return audio_path
# Option 1: Build the UI with gr.Blocks and stream responses (recommended)
with gr.Blocks() as demo:
gr.Markdown("# Oolel")
chatbot = gr.Chatbot()
msg = gr.MultimodalTextbox(
label="Your Request",
file_types=["image", "video"],
file_count="multiple"
)
audio_output = gr.Audio(label="Generated Speech")
clear = gr.Button("Clear")
def respond(message, chat_history):
# Add user message to chat history
bot_message = ""
chat_history.append([message["text"], ""])
# Stream the response
for response in model_inference(message, chat_history):
bot_message = response
chat_history[-1][1] = bot_message
yield "", chat_history, None
# Generate audio after streaming is complete
try:
if bot_message.strip(): # Only generate TTS if there's actual text
audio_path = tts(bot_message)
if audio_path:
yield "", chat_history, audio_path
else:
print("TTS returned None or empty result")
yield "", chat_history, None
else:
yield "", chat_history, None
except Exception as e:
print(f"TTS Error: {e}")
yield "", chat_history, None
msg.submit(respond, [msg, chatbot], [msg, chatbot, audio_output])
clear.click(lambda: ([], None), outputs=[chatbot, audio_output])
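    # msg.submit streams (textbox, chat history, audio) updates from respond(); the Clear
    # button resets both the chat history and the audio player.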
# Option 2: Use ChatInterface without outputs parameter (simpler but no audio)
# demo = gr.ChatInterface(
# fn=model_inference,
# description="# oolel-vision-experimental `@video-infer for video understanding`**",
# fill_height=True,
# textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
# stop_btn="Stop Generation",
# multimodal=True,
# cache_examples=False,
# )
if __name__ == "__main__":
demo.launch(debug=True)