import gradio as gr
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces
import cv2
import numpy as np
from PIL import Image
from gradio_client import Client, handle_file
import os

# Token and client for the external Wolof XTTS Space used for speech synthesis.
hf_token = os.getenv("HF_TOKEN")
tts_client = Client("dofbi/galsenai-xtts-v2-wolof-inference", hf_token=hf_token)


def tts(text):
    """Generate TTS audio for `text` using the Gradio API client."""
    try:
        result = tts_client.predict(
            text=text,
            audio_reference=handle_file('https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav'),
            api_name="/predict"
        )
        print(f"TTS result: {result}")
        # The Space may return a tuple, a plain filepath string, or a file-like
        # object; normalize all of these to a single audio filepath (or None).
        if isinstance(result, tuple):
            return result[0] if result else None
        elif isinstance(result, str):
            return result
        elif hasattr(result, 'name'):
            return result.name
        else:
            return result
    except Exception as e:
        print(f"TTS API Error: {e}")
        return None
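

# A minimal smoke-test sketch (not part of the app's UI flow): the sample text
# below is an arbitrary placeholder, and the helper is only meant to be called
# manually to verify that the TTS Space is reachable before launching the demo.
def _tts_smoke_test(sample_text: str = "Salaam aleekum") -> bool:
    path = tts(sample_text)
    return path is not None and os.path.exists(str(path))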


def progress_bar_html(label: str) -> str:
    """
    Returns an HTML snippet for a thin progress bar with a label.
    The progress bar is styled as a dark animated bar.
    """
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
    <div style="width: 110px; height: 5px; background-color: #9370DB; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: #4B0082; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
    '''
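

# Note: progress_bar_html() is not wired into the UI below. As a usage sketch
# (an assumption, not the original behavior), the snippet could be yielded as an
# interim chatbot message, e.g. `yield progress_bar_html("Thinking...")`, before
# streaming the actual model output.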


def downsample_video(video_path):
    """
    Downsamples the video to 10 evenly spaced frames.
    Each frame is converted to a PIL Image along with its timestamp (in seconds).
    """
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    if total_frames <= 0 or fps <= 0:
        vidcap.release()
        return frames

    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
    for i in frame_indices:
        # Seek directly to the target frame index and decode it.
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
        success, image = vidcap.read()
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            timestamp = round(i / fps, 2)
            frames.append((pil_image, timestamp))
    vidcap.release()
    return frames
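

# Example of the expected return shape (illustrative only; "sample.mp4" is a
# hypothetical path, not a file shipped with this app):
#   frames = downsample_video("sample.mp4")
#   # -> [(<PIL.Image.Image>, 0.0), (<PIL.Image.Image>, 1.23), ...]  up to 10 pairs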


MODEL_ID = "yaya-sy/chvtr"
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    min_pixels=256 * 28 * 28,
    max_pixels=1280 * 28 * 28,
)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda").eval()


@spaces.GPU
def model_inference(input_dict, history):
    text = input_dict["text"]
    files = input_dict["files"]

    # Video branch: messages prefixed with "@video-infer" are treated as queries
    # about an uploaded video, which is downsampled to at most 10 frames.
    if text.strip().lower().startswith("@video-infer"):
        text = text[len("@video-infer"):].strip()
        if not files:
            raise gr.Error("Please upload a video file along with your @video-infer query.")
        video_path = files[0]
        frames = downsample_video(video_path)
        if not frames:
            raise gr.Error("Could not process video.")

        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": text}]
            }
        ]
        # Interleave a timestamp label with each sampled frame.
        for image, timestamp in frames:
            messages[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
            messages[0]["content"].append({"type": "image", "image": image})

        video_images = [image for image, _ in frames]

        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(
            text=[prompt],
            images=video_images,
            return_tensors="pt",
            padding=True,
        ).to("cuda")

        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        # do_sample=True is needed for temperature/min_p to take effect.
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=128, do_sample=True, temperature=2.0, min_p=0.8)
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            time.sleep(0.001)
            yield buffer
        return

    # Image / text branch.
    if len(files) > 1:
        images = [load_image(image) for image in files]
    elif len(files) == 1:
        images = [load_image(files[0])]
    else:
        images = []

    if text == "" and not images:
        raise gr.Error("Please input a query and optionally image(s).")
    if text == "" and images:
        raise gr.Error("Please input a text query along with the image(s).")

    messages = [
        {
            "role": "user",
            "content": [
                *[{"type": "image", "image": image} for image in images],
                {"type": "text", "text": text},
            ],
        }
    ]
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt],
        images=images if images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda")
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    # do_sample=True is needed for temperature/min_p to take effect.
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=128, do_sample=True, temperature=2.0, min_p=0.8)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer
    # Speech synthesis for the final answer is handled by respond() below.
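

# A compact, non-streaming sketch of the same Qwen2.5-VL call used above
# (an illustration, not part of the app's UI flow). `image_path` and `question`
# are hypothetical inputs; on a ZeroGPU Space this would also need @spaces.GPU.
def describe_image(image_path, question):
    image = load_image(image_path)
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": question},
        ],
    }]
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[prompt], images=[image], return_tensors="pt", padding=True).to("cuda")
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=128)
    # Drop the prompt tokens and decode only the newly generated continuation.
    new_tokens = output_ids[:, inputs["input_ids"].shape[1]:]
    return processor.batch_decode(new_tokens, skip_special_tokens=True)[0]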


with gr.Blocks() as demo:
    gr.Markdown("# Oolel")

    chatbot = gr.Chatbot()
    msg = gr.MultimodalTextbox(
        label="Your Request",
        file_types=["image", "video"],
        file_count="multiple"
    )
    audio_output = gr.Audio(label="Generated Speech")
    clear = gr.Button("Clear")

    def respond(message, chat_history):
        bot_message = ""
        chat_history.append([message["text"], ""])

        # Stream partial model output into the last chat turn.
        for response in model_inference(message, chat_history):
            bot_message = response
            chat_history[-1][1] = bot_message
            yield "", chat_history, None

        # Once the full answer is available, synthesize speech for it.
        try:
            if bot_message.strip():
                audio_path = tts(bot_message)
                if audio_path:
                    yield "", chat_history, audio_path
                else:
                    print("TTS returned None or empty result")
                    yield "", chat_history, None
            else:
                yield "", chat_history, None
        except Exception as e:
            print(f"TTS Error: {e}")
            yield "", chat_history, None

    msg.submit(respond, [msg, chatbot], [msg, chatbot, audio_output])
    clear.click(lambda: ([], None), outputs=[chatbot, audio_output])
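

# Usage note: because respond() is a generator, the chatbot updates as tokens
# stream in while the audio output stays empty; only the final yield carries
# the path of the synthesized speech.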


if __name__ == "__main__":
    demo.launch(debug=True)