Kemo_Chat / app.py
kemo2003's picture
Update app.py
7d2dfb8 verified
# coding: utf-8
import gradio as gr
import torch
import numpy as np
from PIL import Image
import fitz # PyMuPDF
import pandas as pd
from huggingface_hub import login
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
# AutoProcessor, # Replaced by ViltProcessor for VQA
# AutoModelForVision2Seq, # Replaced by ViltForQuestionAnswering for VQA
WhisperProcessor,
WhisperForConditionalGeneration,
ViltProcessor, # Added for ViLT VQA model
ViltForQuestionAnswering # Added for ViLT VQA model
)
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoFeatureExtractor
import os
import scipy.io.wavfile as wavfile
import io
# --- Configuration & Model Loading ---
# Hugging Face Hub Login
token = os.getenv("HF_API_TOKEN")
if token:
login(token=token)
else:
print("تحذير: لم يتم تعيين متغير البيئة HF_API_TOKEN. بعض النماذج قد تتطلب تسجيل الدخول.")
# Device Configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"استخدام الجهاز: {device}")
# 1. Text Generation Model: distilgpt2 (Lightweight)
text_model_name = "distilgpt2"
print(f"تحميل نموذج النص: {text_model_name}")
text_tokenizer = AutoTokenizer.from_pretrained(text_model_name)
text_model = AutoModelForCausalLM.from_pretrained(
text_model_name,
torch_dtype=torch.float32,
device_map="auto"
)
if text_tokenizer.pad_token is None:
text_tokenizer.pad_token = text_tokenizer.eos_token
print("تم تحميل نموذج النص.")
# 2. Image Analysis Model: dandelin/vilt-b32-finetuned-vqa (Lightweight, Public VQA)
image_model_name = "dandelin/vilt-b32-finetuned-vqa"
print(f"تحميل نموذج الصور (VQA): {image_model_name}")
image_processor = ViltProcessor.from_pretrained(image_model_name)
image_model = ViltForQuestionAnswering.from_pretrained(
image_model_name,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
device_map="auto"
)
print("تم تحميل نموذج الصور (VQA).")
# 3. Speech-to-Text Model: openai/whisper-tiny (Lightweight)
stt_model_name = "openai/whisper-tiny"
print(f"تحميل نموذج تحويل الكلام إلى نص: {stt_model_name}")
stt_processor = WhisperProcessor.from_pretrained(stt_model_name)
stt_model = WhisperForConditionalGeneration.from_pretrained(stt_model_name).to(device)
stt_model.config.forced_decoder_ids = None
print("تم تحميل نموذج تحويل الكلام إلى نص.")
# 4. Text-to-Speech Model: parler-tts/parler-tts-tiny-v1 (Lightweight)
tts_model_repo_id = "parler-tts/parler-tts-tiny-v1"
print(f"تحميل نموذج تحويل النص إلى كلام: {tts_model_repo_id}")
tts_model = ParlerTTSForConditionalGeneration.from_pretrained(tts_model_repo_id).to(device)
tts_feature_extractor = AutoFeatureExtractor.from_pretrained(tts_model_repo_id)
print("تم تحميل مكونات تحويل النص إلى كلام.")
# --- Helper Functions for Model Inference ---
# 1. Text Generation (using distilgpt2)
def generate_text_response(prompt_text):
try:
full_prompt = f"السؤال: {prompt_text}\nالإجابة الودية والواضحة:"
inputs = text_tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).to(text_model.device)
outputs = text_model.generate(
**inputs,
max_new_tokens=150,
temperature=0.7,
top_k=50,
do_sample=True,
pad_token_id=text_tokenizer.eos_token_id,
no_repeat_ngram_size=2
)
response_text = text_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
return response_text.strip()
except Exception as e:
print(f"خطأ في توليد النص: {e}")
return f"خطأ في معالجة النص: {str(e)}"
# 2. Image Analysis (using dandelin/vilt-b32-finetuned-vqa)
def analyze_image(pil_image, question_text=None):
try:
if pil_image is None:
return "الرجاء رفع صورة أولاً."
if isinstance(pil_image, np.ndarray):
pil_image = Image.fromarray(pil_image).convert("RGB")
else:
pil_image = pil_image.convert("RGB")
if not question_text or question_text.strip() == "":
# ViLT is a VQA model, it needs a question.
# If no question, we can ask a generic one, or return a message.
# For now, let's ask a generic question if none is provided.
question_text = "What is in this image?"
# Prepare inputs for ViLT
encoding = image_processor(pil_image, question_text, return_tensors="pt").to(image_model.device)
# Forward pass
with torch.no_grad():
outputs = image_model(**encoding)
logits = outputs.logits
idx = logits.argmax(-1).item()
response_text = image_model.config.id2label[idx]
return response_text
except Exception as e:
print(f"خطأ في تحليل الصورة: {e}")
return f"خطأ في تحليل الصورة: {str(e)}"
# 3. Audio Processing (STT with Whisper Tiny and TTS with ParlerTTS Tiny)
def process_audio(audio_input):
try:
if audio_input is None:
return "الرجاء تسجيل الصوت أولاً.", "", (16000, np.array([], dtype=np.int16))
sample_rate, audio_data = audio_input
if audio_data.dtype != np.float32:
audio_data = audio_data.astype(np.float32)
if np.max(np.abs(audio_data)) > 0:
audio_data = audio_data / np.max(np.abs(audio_data))
else:
return "تم استقبال صوت صامت.", "", (16000, np.array([], dtype=np.int16))
input_features = stt_processor(audio_data, sampling_rate=sample_rate, return_tensors="pt").input_features.to(device)
predicted_ids = stt_model.generate(input_features, language="ar")
transcription = stt_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
if not transcription:
return "لم يتمكن النموذج من استخراج نص من الصوت.", "", (16000, np.array([], dtype=np.int16))
text_response = generate_text_response(transcription)
prompt = text_response
with torch.no_grad():
generation_output = tts_model.generate(input_ids=None,
prompt=prompt,
do_sample=True,
temperature=1.0).cpu().numpy().squeeze()
audio_output_np = generation_output
tts_sample_rate = tts_model.config.sampling_rate
return transcription, text_response, (tts_sample_rate, audio_output_np)
except Exception as e:
print(f"خطأ في معالجة الصوت: {e}")
empty_audio_data = np.array([], dtype=np.float32)
return f"خطأ في معالجة الصوت: {str(e)}", "", (16000, empty_audio_data)
# 4. File Processing
def process_file(file_obj):
try:
if file_obj is None:
return "الرجاء رفع ملف أولاً."
file_path = file_obj.name
text_content = ""
if file_path.endswith(".pdf"):
with fitz.open(file_path) as doc:
text_content = "\n".join(page.get_text() for page in doc)
elif file_path.endswith((".xlsx", ".xls")):
df = pd.read_excel(file_path)
text_content = df.to_string()
elif file_path.endswith(".csv"):
df = pd.read_csv(file_path)
text_content = df.to_string()
else:
return "❌ نوع الملف غير مدعوم حالياً (يدعم PDF, Excel, CSV)."
if not text_content.strip():
return "الملف فارغ أو لا يمكن قراءة محتواه النصي."
max_context_len = 1000
if len(text_content) > max_context_len:
text_content = text_content[:max_context_len] + "... [المحتوى تم اختصاره]"
response = generate_text_response(f"لخص المحتوى التالي من الملف: \n\n{text_content}")
return response
except Exception as e:
print(f"خطأ في معالجة الملف: {e}")
return f"خطأ في قراءة الملف: {str(e)}"
# --- Gradio Interface ---
with gr.Blocks(css=".gradio-container {background-color: #f0f4f8; font-family: Arial; color: #333; padding: 20px;}", theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🤖 Kemo Chat V3.2 - مساعد ذكي متعدد الوسائط (نماذج خفيفة الوزن - ViLT VQA)")
gr.Markdown("🎯 تفاعل معي عبر النصوص، الصور، الصوت أو الملفات! (باستخدام نماذج أقل استهلاكًا للذاكرة).")
gr.Markdown("📁 يدعم الملفات: PDF، Excel، CSV\n🖼️ يدعم الإجابة على الأسئلة حول الصور (VQA)\n🎙️ تحويل الصوت إلى نص والرد صوتياً")
with gr.Tab("💬 المحادثة النصية"):
text_input = gr.Textbox(label="اكتب سؤالك أو رسالتك هنا", lines=3)
text_output = gr.Textbox(label="الرد", lines=5, interactive=False)
text_submit = gr.Button("إرسال", variant="primary")
text_submit.click(fn=generate_text_response, inputs=text_input, outputs=text_output)
with gr.Tab("🖼️ تحليل الصور (سؤال وجواب)"):
gr.Markdown("ارفع صورة واطرح سؤالاً عنها.")
with gr.Row():
image_input = gr.Image(type="pil", label="ارفع صورة")
with gr.Column():
image_question = gr.Textbox(label="اطرح سؤالاً عن الصورة (مطلوب لـ ViLT)", lines=2, placeholder="مثال: What color is the car?")
image_output = gr.Textbox(label="الإجابة", lines=5, interactive=False)
image_submit = gr.Button("تحليل الصورة", variant="primary")
image_submit.click(fn=analyze_image, inputs=[image_input, image_question], outputs=image_output)
with gr.Tab("🎤 التفاعل الصوتي"):
gr.Markdown("سجّل رسالة صوتية، سيتم تحويلها إلى نص، ثم الرد عليها نصيًا وصوتيًا.")
audio_input = gr.Audio(sources=["microphone"], type="numpy", label="سجّل رسالتك الصوتية")
with gr.Row():
audio_transcription = gr.Textbox(label="النص المستخرج من صوتك", interactive=False, lines=2)
audio_text_response = gr.Textbox(label="الرد النصي على رسالتك", interactive=False, lines=3)
audio_output_player = gr.Audio(label="الرد الصوتي من المساعد", type="numpy", interactive=False)
audio_submit = gr.Button("معالجة الصوت", variant="primary")
audio_submit.click(fn=process_audio,
inputs=audio_input,
outputs=[audio_transcription, audio_text_response, audio_output_player])
with gr.Tab("📄 تحليل الملفات"):
gr.Markdown("ارفع ملف (PDF, Excel, CSV) وسأقوم بتلخيص محتواه أو الإجابة على أسئلتك حوله.")
file_input = gr.File(label="ارفع ملفك (PDF, Excel, CSV)", file_types=[".pdf", ".xls", ".xlsx", ".csv"])
file_output = gr.Textbox(label="الرد على محتوى الملف", lines=5, interactive=False)
file_submit = gr.Button("تحليل الملف", variant="primary")
file_submit.click(fn=process_file, inputs=file_input, outputs=file_output)
if __name__ == "__main__":
print("Launching Gradio Demo (Lightweight Models with ViLT VQA)...")
demo.launch(share=True)