|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import fitz |
|
|
import pandas as pd |
|
|
from huggingface_hub import login |
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
AutoModelForCausalLM, |
|
|
|
|
|
|
|
|
WhisperProcessor, |
|
|
WhisperForConditionalGeneration, |
|
|
ViltProcessor, |
|
|
ViltForQuestionAnswering |
|
|
) |
|
|
from parler_tts import ParlerTTSForConditionalGeneration |
|
|
from transformers import AutoFeatureExtractor |
|
|
|
|
|
import os |
|
|
import scipy.io.wavfile as wavfile |
|
|
import io |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
token = os.getenv("HF_API_TOKEN") |
|
|
if token: |
|
|
login(token=token) |
|
|
else: |
|
|
print("تحذير: لم يتم تعيين متغير البيئة HF_API_TOKEN. بعض النماذج قد تتطلب تسجيل الدخول.") |
|
|
|
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
print(f"استخدام الجهاز: {device}") |
|
|
|
|
|
|
|
|
text_model_name = "distilgpt2" |
|
|
print(f"تحميل نموذج النص: {text_model_name}") |
|
|
text_tokenizer = AutoTokenizer.from_pretrained(text_model_name) |
|
|
text_model = AutoModelForCausalLM.from_pretrained( |
|
|
text_model_name, |
|
|
torch_dtype=torch.float32, |
|
|
device_map="auto" |
|
|
) |
|
|
if text_tokenizer.pad_token is None: |
|
|
text_tokenizer.pad_token = text_tokenizer.eos_token |
|
|
print("تم تحميل نموذج النص.") |
|
|
|
|
|
|
|
|
image_model_name = "dandelin/vilt-b32-finetuned-vqa" |
|
|
print(f"تحميل نموذج الصور (VQA): {image_model_name}") |
|
|
image_processor = ViltProcessor.from_pretrained(image_model_name) |
|
|
image_model = ViltForQuestionAnswering.from_pretrained( |
|
|
image_model_name, |
|
|
torch_dtype=torch.float16 if device == "cuda" else torch.float32, |
|
|
device_map="auto" |
|
|
) |
|
|
print("تم تحميل نموذج الصور (VQA).") |
|
|
|
|
|
|
|
|
stt_model_name = "openai/whisper-tiny" |
|
|
print(f"تحميل نموذج تحويل الكلام إلى نص: {stt_model_name}") |
|
|
stt_processor = WhisperProcessor.from_pretrained(stt_model_name) |
|
|
stt_model = WhisperForConditionalGeneration.from_pretrained(stt_model_name).to(device) |
|
|
stt_model.config.forced_decoder_ids = None |
|
|
print("تم تحميل نموذج تحويل الكلام إلى نص.") |
|
|
|
|
|
|
|
|
tts_model_repo_id = "parler-tts/parler-tts-tiny-v1" |
|
|
print(f"تحميل نموذج تحويل النص إلى كلام: {tts_model_repo_id}") |
|
|
tts_model = ParlerTTSForConditionalGeneration.from_pretrained(tts_model_repo_id).to(device) |
|
|
tts_feature_extractor = AutoFeatureExtractor.from_pretrained(tts_model_repo_id) |
|
|
print("تم تحميل مكونات تحويل النص إلى كلام.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_text_response(prompt_text): |
|
|
try: |
|
|
full_prompt = f"السؤال: {prompt_text}\nالإجابة الودية والواضحة:" |
|
|
inputs = text_tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).to(text_model.device) |
|
|
outputs = text_model.generate( |
|
|
**inputs, |
|
|
max_new_tokens=150, |
|
|
temperature=0.7, |
|
|
top_k=50, |
|
|
do_sample=True, |
|
|
pad_token_id=text_tokenizer.eos_token_id, |
|
|
no_repeat_ngram_size=2 |
|
|
) |
|
|
response_text = text_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) |
|
|
return response_text.strip() |
|
|
except Exception as e: |
|
|
print(f"خطأ في توليد النص: {e}") |
|
|
return f"خطأ في معالجة النص: {str(e)}" |
|
|
|
|
|
|
|
|
def analyze_image(pil_image, question_text=None): |
|
|
try: |
|
|
if pil_image is None: |
|
|
return "الرجاء رفع صورة أولاً." |
|
|
if isinstance(pil_image, np.ndarray): |
|
|
pil_image = Image.fromarray(pil_image).convert("RGB") |
|
|
else: |
|
|
pil_image = pil_image.convert("RGB") |
|
|
|
|
|
if not question_text or question_text.strip() == "": |
|
|
|
|
|
|
|
|
|
|
|
question_text = "What is in this image?" |
|
|
|
|
|
|
|
|
encoding = image_processor(pil_image, question_text, return_tensors="pt").to(image_model.device) |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = image_model(**encoding) |
|
|
|
|
|
logits = outputs.logits |
|
|
idx = logits.argmax(-1).item() |
|
|
response_text = image_model.config.id2label[idx] |
|
|
|
|
|
return response_text |
|
|
except Exception as e: |
|
|
print(f"خطأ في تحليل الصورة: {e}") |
|
|
return f"خطأ في تحليل الصورة: {str(e)}" |
|
|
|
|
|
|
|
|
def process_audio(audio_input): |
|
|
try: |
|
|
if audio_input is None: |
|
|
return "الرجاء تسجيل الصوت أولاً.", "", (16000, np.array([], dtype=np.int16)) |
|
|
|
|
|
sample_rate, audio_data = audio_input |
|
|
|
|
|
if audio_data.dtype != np.float32: |
|
|
audio_data = audio_data.astype(np.float32) |
|
|
if np.max(np.abs(audio_data)) > 0: |
|
|
audio_data = audio_data / np.max(np.abs(audio_data)) |
|
|
else: |
|
|
return "تم استقبال صوت صامت.", "", (16000, np.array([], dtype=np.int16)) |
|
|
|
|
|
input_features = stt_processor(audio_data, sampling_rate=sample_rate, return_tensors="pt").input_features.to(device) |
|
|
predicted_ids = stt_model.generate(input_features, language="ar") |
|
|
transcription = stt_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip() |
|
|
|
|
|
if not transcription: |
|
|
return "لم يتمكن النموذج من استخراج نص من الصوت.", "", (16000, np.array([], dtype=np.int16)) |
|
|
|
|
|
text_response = generate_text_response(transcription) |
|
|
|
|
|
prompt = text_response |
|
|
with torch.no_grad(): |
|
|
generation_output = tts_model.generate(input_ids=None, |
|
|
prompt=prompt, |
|
|
do_sample=True, |
|
|
temperature=1.0).cpu().numpy().squeeze() |
|
|
audio_output_np = generation_output |
|
|
tts_sample_rate = tts_model.config.sampling_rate |
|
|
|
|
|
return transcription, text_response, (tts_sample_rate, audio_output_np) |
|
|
except Exception as e: |
|
|
print(f"خطأ في معالجة الصوت: {e}") |
|
|
empty_audio_data = np.array([], dtype=np.float32) |
|
|
return f"خطأ في معالجة الصوت: {str(e)}", "", (16000, empty_audio_data) |
|
|
|
|
|
|
|
|
def process_file(file_obj): |
|
|
try: |
|
|
if file_obj is None: |
|
|
return "الرجاء رفع ملف أولاً." |
|
|
file_path = file_obj.name |
|
|
text_content = "" |
|
|
if file_path.endswith(".pdf"): |
|
|
with fitz.open(file_path) as doc: |
|
|
text_content = "\n".join(page.get_text() for page in doc) |
|
|
elif file_path.endswith((".xlsx", ".xls")): |
|
|
df = pd.read_excel(file_path) |
|
|
text_content = df.to_string() |
|
|
elif file_path.endswith(".csv"): |
|
|
df = pd.read_csv(file_path) |
|
|
text_content = df.to_string() |
|
|
else: |
|
|
return "❌ نوع الملف غير مدعوم حالياً (يدعم PDF, Excel, CSV)." |
|
|
|
|
|
if not text_content.strip(): |
|
|
return "الملف فارغ أو لا يمكن قراءة محتواه النصي." |
|
|
|
|
|
max_context_len = 1000 |
|
|
if len(text_content) > max_context_len: |
|
|
text_content = text_content[:max_context_len] + "... [المحتوى تم اختصاره]" |
|
|
|
|
|
response = generate_text_response(f"لخص المحتوى التالي من الملف: \n\n{text_content}") |
|
|
return response |
|
|
except Exception as e: |
|
|
print(f"خطأ في معالجة الملف: {e}") |
|
|
return f"خطأ في قراءة الملف: {str(e)}" |
|
|
|
|
|
|
|
|
with gr.Blocks(css=".gradio-container {background-color: #f0f4f8; font-family: Arial; color: #333; padding: 20px;}", theme=gr.themes.Soft()) as demo: |
|
|
gr.Markdown("# 🤖 Kemo Chat V3.2 - مساعد ذكي متعدد الوسائط (نماذج خفيفة الوزن - ViLT VQA)") |
|
|
gr.Markdown("🎯 تفاعل معي عبر النصوص، الصور، الصوت أو الملفات! (باستخدام نماذج أقل استهلاكًا للذاكرة).") |
|
|
gr.Markdown("📁 يدعم الملفات: PDF، Excel، CSV\n🖼️ يدعم الإجابة على الأسئلة حول الصور (VQA)\n🎙️ تحويل الصوت إلى نص والرد صوتياً") |
|
|
|
|
|
with gr.Tab("💬 المحادثة النصية"): |
|
|
text_input = gr.Textbox(label="اكتب سؤالك أو رسالتك هنا", lines=3) |
|
|
text_output = gr.Textbox(label="الرد", lines=5, interactive=False) |
|
|
text_submit = gr.Button("إرسال", variant="primary") |
|
|
text_submit.click(fn=generate_text_response, inputs=text_input, outputs=text_output) |
|
|
|
|
|
with gr.Tab("🖼️ تحليل الصور (سؤال وجواب)"): |
|
|
gr.Markdown("ارفع صورة واطرح سؤالاً عنها.") |
|
|
with gr.Row(): |
|
|
image_input = gr.Image(type="pil", label="ارفع صورة") |
|
|
with gr.Column(): |
|
|
image_question = gr.Textbox(label="اطرح سؤالاً عن الصورة (مطلوب لـ ViLT)", lines=2, placeholder="مثال: What color is the car?") |
|
|
image_output = gr.Textbox(label="الإجابة", lines=5, interactive=False) |
|
|
image_submit = gr.Button("تحليل الصورة", variant="primary") |
|
|
image_submit.click(fn=analyze_image, inputs=[image_input, image_question], outputs=image_output) |
|
|
|
|
|
with gr.Tab("🎤 التفاعل الصوتي"): |
|
|
gr.Markdown("سجّل رسالة صوتية، سيتم تحويلها إلى نص، ثم الرد عليها نصيًا وصوتيًا.") |
|
|
audio_input = gr.Audio(sources=["microphone"], type="numpy", label="سجّل رسالتك الصوتية") |
|
|
with gr.Row(): |
|
|
audio_transcription = gr.Textbox(label="النص المستخرج من صوتك", interactive=False, lines=2) |
|
|
audio_text_response = gr.Textbox(label="الرد النصي على رسالتك", interactive=False, lines=3) |
|
|
audio_output_player = gr.Audio(label="الرد الصوتي من المساعد", type="numpy", interactive=False) |
|
|
audio_submit = gr.Button("معالجة الصوت", variant="primary") |
|
|
audio_submit.click(fn=process_audio, |
|
|
inputs=audio_input, |
|
|
outputs=[audio_transcription, audio_text_response, audio_output_player]) |
|
|
|
|
|
with gr.Tab("📄 تحليل الملفات"): |
|
|
gr.Markdown("ارفع ملف (PDF, Excel, CSV) وسأقوم بتلخيص محتواه أو الإجابة على أسئلتك حوله.") |
|
|
file_input = gr.File(label="ارفع ملفك (PDF, Excel, CSV)", file_types=[".pdf", ".xls", ".xlsx", ".csv"]) |
|
|
file_output = gr.Textbox(label="الرد على محتوى الملف", lines=5, interactive=False) |
|
|
file_submit = gr.Button("تحليل الملف", variant="primary") |
|
|
file_submit.click(fn=process_file, inputs=file_input, outputs=file_output) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
print("Launching Gradio Demo (Lightweight Models with ViLT VQA)...") |
|
|
demo.launch(share=True) |
|
|
|
|
|
|