Spaces:
Running
Running
| import os | |
| import re | |
| import gradio as gr | |
| import spaces | |
# CRITICAL: Disable PyTorch compiler settings
# These variables are written into the process environment *before* the
# `import torch` statement below, so they are already present when torch is
# imported. Do not move them after the import.
# NOTE(review): PYTORCH_NO_CUDA_MEMORY_CACHING and CUDA_LAUNCH_BLOCKING look
# like debugging knobs rather than compiler switches and presumably slow
# inference down - confirm they are really wanted in production.
os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCH_INDUCTOR_DISABLE"] = "1"
os.environ["TORCHINDUCTOR_DISABLE_CUDAGRAPHS"] = "1"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "0"
# Import torch and disable dynamo
import torch
# Every access is guarded with hasattr() so the script still starts on torch
# builds that do not ship `_dynamo`, or whose `_dynamo` lacks `config` or
# `disable`.
if hasattr(torch, "_dynamo"):
    if hasattr(torch._dynamo, "config"):
        # Ask dynamo to suppress its own errors instead of raising them.
        torch._dynamo.config.suppress_errors = True
    if hasattr(torch._dynamo, "disable"):
        # NOTE(review): the return value is discarded; called without a
        # function this presumably just returns a decorator/context object
        # and does not flip a global switch - confirm. The environment
        # variables above are likely what actually takes effect.
        torch._dynamo.disable()
    print("Disabled torch._dynamo")
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Process-wide cache for the loaded tokenizer and model. Both start empty and
# are filled in by load_models() the first time a load succeeds.
global_tokenizer = None
global_model = None

# Hugging Face Hub repository the tokenizer and weights are pulled from.
model_id = "CohereForAI/c4ai-command-r7b-arabic-02-2025"

# Hub access token read from the environment; None when HF_TOKEN is unset.
hf_token = os.getenv("HF_TOKEN")
# Basic function to load models
def load_models():
    """Return the shared ``(tokenizer, model)`` pair, loading it on first use.

    The pair is cached in the module-level ``global_tokenizer`` /
    ``global_model`` variables, so the download and weight loading happen at
    most once per process. Nothing is cached when loading fails, which means
    a later call retries from scratch.

    Returns:
        ``(tokenizer, model)`` on success, or ``(None, None)`` when
        ``hf_token`` is empty or loading raised an exception. The reason is
        printed to stdout in both failure cases.
    """
    global global_model, global_tokenizer

    # Fast path: both halves are already cached.
    if global_model is not None and global_tokenizer is not None:
        return global_tokenizer, global_model

    if not hf_token:
        # Say why loading was skipped instead of failing silently.
        print("Error loading model: HF_TOKEN is not set")
        return None, None

    # Half precision with automatic placement on GPU; full precision on CPU.
    has_cuda = torch.cuda.is_available()
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=hf_token,
            torch_dtype=torch.float16 if has_cuda else torch.float32,
            device_map="auto" if has_cuda else "cpu",
            use_cache=True,
            # Public spelling of the attention backend switch; replaces the
            # deprecated `use_flash_attention_2=False` and the private
            # `_attn_implementation` keyword.
            attn_implementation="eager",
        )
    except Exception as e:
        # Top-level boundary for the UI: report and let the caller show a
        # friendly message rather than crashing the app.
        print(f"Error loading model: {str(e)}")
        return None, None

    # Publish to the cache only after both objects loaded successfully.
    global_model = model
    global_tokenizer = tokenizer
    return tokenizer, model
# Enhanced clean response function with better metadata cleaning
# Noise patterns, compiled once at import time. clean_response() deletes the
# matches of each pattern in exactly this order; later patterns see the text
# left behind by earlier ones, so do not reorder or merge entries.
_NOISE_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        # Step 1: forum-style metadata header at the very start of the text
        r'^\s*[\?؟]\s*موضوع:.*?التالي:\s*[^؟?]*[؟?]\s*[^\n]*\d{4},\s*\d{1,2}:\d{2}\s*[apm]+\s*',
        r'^\s*[\?؟]\s*منتديات.*?حائل.*?[^\n]*\d{4},\s*\d{1,2}:\d{2}',
        # Step 2: date and timestamp patterns
        r'\d{1,2}\s+[^ ]+\s+\d{4}\s*[-,]\s*\d{1,2}:\d{2}\s*[صمaApP][مmMnN]?',
        # Step 3: remaining forum metadata
        r'^\s*[\?؟]\s*[^\n]+:\s*[^\n]+\?\s*',
        r'منتديات.*?نور[^\n]*$',
        # Step 4: website references introduced by a dash
        r'[-–—]\s*موقع\s+[^\n]+',
        r'[-–—]\s*[^\n]*المصطبه[^\n]*',
        # Step 5: stray "who is / what is" questions ending in an Arabic
        # question mark
        r'من هو [^؟?]+\؟',
        r'من هي [^؟?]+\؟',
        r'ما هو [^؟?]+\؟',
        r'ما هي [^؟?]+\؟',
    )
)
# Step 6: any run of whitespace (including newlines) becomes one space.
_WHITESPACE_RUN = re.compile(r'\s+')
# Step 7: punctuation left dangling at the start after the removals above.
_LEADING_PUNCTUATION = re.compile(r'^[:.،,؛;-]+\s*')


def clean_response(text):
    """Strip forum/website noise from generated text and normalise spacing.

    Args:
        text: Raw decoded model output. ``None`` and ``""`` are accepted.

    Returns:
        The text with every ``_NOISE_PATTERNS`` match removed, all whitespace
        runs collapsed to single spaces, surrounding whitespace trimmed and
        leading punctuation dropped. Returns ``""`` for ``None`` or empty
        input.
    """
    if not text:
        # Covers both None (which re.sub would reject) and the empty string.
        return ""

    for noise in _NOISE_PATTERNS:
        text = noise.sub('', text)

    text = _WHITESPACE_RUN.sub(' ', text).strip()
    return _LEADING_PUNCTUATION.sub('', text)
# Generate text function with GPU access
def generate_text(prompt):
    """Generate a cleaned model answer for ``prompt``.

    Args:
        prompt: The user's question as entered in the UI. ``None``, empty and
            whitespace-only values are treated as "no question".

    Returns:
        The generated continuation after ``clean_response`` has been applied,
        or an Arabic user-facing message when the prompt is empty, the model
        cannot be loaded, or generation raises. This function never raises;
        every failure is reported through the returned string.
    """
    # Guard against None as well as blank input before touching the model.
    if not prompt or not prompt.strip():
        return "يرجى إدخال سؤال."

    try:
        tokenizer, model = load_models()
        if tokenizer is None or model is None:
            return "خطأ في تحميل النموذج."

        # The prompt is sent as-is, without a system context.
        # NOTE(review): the checkpoint name suggests an instruction-tuned
        # model; skipping the tokenizer's chat template may be why the
        # output needs heavy cleaning afterwards - confirm before changing.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_length = inputs.input_ids.shape[1]

        with torch.inference_mode():
            outputs = model.generate(
                input_ids=inputs.input_ids,
                # Forward the mask so generate() does not have to guess it.
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=300,
                # Greedy decoding. No temperature is passed because it has
                # no effect when sampling is off and only triggers a warning.
                do_sample=False,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        generated_text = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
        # Remove metadata noise before showing the answer.
        return clean_response(generated_text)
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing.
        return f"خطأ في توليد النص: {str(e)}"
# Example setters
# Quick-example prompts, in the order of the four setters below.
_EXAMPLE_PROMPTS = (
    'من كتب قصيدة "على قدر أهل العزم تأتي العزائم"؟',
    "ما هي عاصمة السعودية؟",
    "اشرح مفهوم الذكاء الاصطناعي",
    "ماهو شهر رمضان؟",
)


def set_example1():
    """Return the first quick-example prompt."""
    return _EXAMPLE_PROMPTS[0]


def set_example2():
    """Return the second quick-example prompt."""
    return _EXAMPLE_PROMPTS[1]


def set_example3():
    """Return the third quick-example prompt."""
    return _EXAMPLE_PROMPTS[2]


def set_example4():
    """Return the fourth quick-example prompt."""
    return _EXAMPLE_PROMPTS[3]
# Create custom CSS with enhanced fonts and distinctive buttons
# Stylesheet handed to gr.Blocks(css=...). It imports two Arabic web fonts,
# forces right-to-left text areas, and styles components by the ids and
# classes assigned in the layout: #input-text, #output-text, #generate-btn,
# #clear-btn and .example-btn.
# NOTE(review): .input-container, .output-container, .container, .row and
# .col are not assigned to any component in the visible layout - confirm
# whether these rules are still needed.
custom_css = """
/* Import improved Arabic fonts from Google */
@import url('https://fonts.googleapis.com/css2?family=Tajawal:wght@400;500;700&display=swap');
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Sans+Arabic:wght@400;500;600&display=swap');
/* Enhanced styles for Arabic Cohere model interface */
:root {
--primary-color: #1F4287;
--secondary-color: #278EA5;
--bg-color: #f9fafb;
--border-color: #d1d5db;
--btn-primary: #2563EB; /* Bright blue */
--btn-secondary: #6B7280; /* Gray */
--btn-primary-hover: #1D4ED8;
--btn-secondary-hover: #4B5563;
--example-btn-bg: #F3F4F6;
--example-btn-border: #D1D5DB;
--text-color: #000000;
}
/* Base styles */
body, html {
font-family: 'Tajawal', 'IBM Plex Sans Arabic', 'Arial', sans-serif !important;
margin: 0;
padding: 0;
background-color: var(--bg-color);
overflow-x: hidden;
}
/* Typography */
h1, h2, h3, button, label {
font-family: 'Tajawal', 'IBM Plex Sans Arabic', 'Arial', sans-serif !important;
color: var(--primary-color);
text-align: center;
font-weight: 700 !important;
}
/* Fix input and output containers */
.input-container, .output-container {
border: 1px solid var(--border-color);
border-radius: 8px;
padding: 10px;
margin-bottom: 15px;
background-color: white;
}
/* CRITICAL FIX: Make sure text is visible in textboxes with better fonts */
textarea, .output-text {
font-family: 'IBM Plex Sans Arabic', 'Tajawal', 'Arial', sans-serif !important;
color: black !important;
background-color: white !important;
border: 1px solid #d1d5db !important;
padding: 12px !important;
font-size: 16px !important;
line-height: 1.6 !important;
border-radius: 8px !important;
width: 100% !important;
direction: rtl !important;
min-height: 80px !important;
font-weight: 500 !important;
letter-spacing: 0.2px !important;
}
/* Ensure text and placeholder are visible */
textarea::placeholder {
color: #9ca3af !important;
opacity: 1 !important;
font-family: 'IBM Plex Sans Arabic', 'Tajawal', 'Arial', sans-serif !important;
}
/* Button styling to match current design but more distinct */
button {
border-radius: 8px !important;
padding: 10px 20px !important;
font-weight: 600 !important;
transition: all 0.2s ease !important;
cursor: pointer !important;
text-align: center !important;
font-size: 15px !important;
box-shadow: 0 1px 3px rgba(0,0,0,0.1) !important;
margin: 5px !important;
}
/* SPECIAL STYLING FOR GENERATE/CLEAR BUTTONS */
/* Generate Button - Distinctive styling */
#generate-btn {
background: linear-gradient(135deg, #2563EB, #3B82F6) !important;
color: white !important;
border: none !important;
font-weight: 700 !important;
padding: 12px 24px !important;
box-shadow: 0 4px 6px rgba(37, 99, 235, 0.25) !important;
transform: translateY(0) !important;
font-size: 16px !important;
}
#generate-btn:hover {
background: linear-gradient(135deg, #1D4ED8, #2563EB) !important;
box-shadow: 0 6px 8px rgba(37, 99, 235, 0.3) !important;
transform: translateY(-2px) !important;
}
/* Clear Button - Distinctive styling */
#clear-btn {
background-color: #F3F4F6 !important;
color: #4B5563 !important;
border: 1px solid #D1D5DB !important;
font-weight: 600 !important;
padding: 12px 24px !important;
font-size: 16px !important;
}
#clear-btn:hover {
background-color: #E5E7EB !important;
color: #374151 !important;
}
/* Example buttons styling */
.example-btn {
background-color: var(--example-btn-bg) !important;
border: 1px solid var(--example-btn-border) !important;
color: var(--primary-color) !important;
padding: 8px 12px !important;
border-radius: 6px !important;
margin: 4px !important;
font-size: 14px !important;
font-weight: 500 !important;
}
.example-btn:hover {
background-color: #E5E7EB !important;
border-color: #9CA3AF !important;
}
/* Layout containers */
.container {
max-width: 1200px;
margin: 0 auto;
padding: 20px;
}
.row {
display: flex;
gap: 20px;
margin-bottom: 15px;
}
.col {
flex: 1;
}
/* Explicitly force element visibility */
#input-text, #output-text {
visibility: visible !important;
display: block !important;
opacity: 1 !important;
}
/* Improved labels */
label {
font-size: 16px !important;
font-weight: 600 !important;
margin-bottom: 5px !important;
display: block !important;
text-align: right !important;
color: #1F2937 !important;
}
"""
# Create a Gradio interface that matches the current design with better fonts
# Layout: a title, then one row with two columns. The first column holds the
# prompt box, two rows of quick-example buttons and the generate/clear
# buttons; the second column holds the generated text.
with gr.Blocks(title="Cohere Arabic Model", css=custom_css) as demo:
    gr.Markdown("""# ⭐ نموذج أرحب للغة العربية | Command R7B Arabic Model""")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                elem_id="input-text",
                label="النص الإدخال | Input Prompt",
                placeholder="أدخل سؤالك باللغة العربية هنا...",
                lines=4,
            )
            gr.Markdown("### أمثلة سريعة | Quick Examples")

            # One inner tuple per visual row: (button label, click handler).
            _example_rows = (
                (
                    ('من كتب قصيدة "على قدر أهل العزم تأتي العزائم"؟', set_example1),
                    ("ما هي عاصمة السعودية؟", set_example2),
                ),
                (
                    ("اشرح مفهوم الذكاء الاصطناعي", set_example3),
                    ("ماهو شهر رمضان؟", set_example4),
                ),
            )
            _example_buttons = []
            _example_handlers = []
            for _row in _example_rows:
                with gr.Row():
                    for _label, _handler in _row:
                        _example_buttons.append(
                            gr.Button(_label, elem_classes=["example-btn"])
                        )
                        _example_handlers.append(_handler)
            # Keep the original module-level button names available.
            ex1, ex2, ex3, ex4 = _example_buttons

            with gr.Row():
                submit_btn = gr.Button("توليد النص | Generate", elem_id="generate-btn")
                clear_btn = gr.Button("مسح | Clear", elem_id="clear-btn")

        with gr.Column():
            output_text = gr.Textbox(
                elem_id="output-text",
                label="النص المولد | Generated Text",
                lines=10,
            )

    # Set up the event handlers
    # Each example button writes its prompt into the input box.
    for _button, _handler in zip(_example_buttons, _example_handlers):
        _button.click(fn=_handler, outputs=input_text)
    # Generate reads the input box and fills the output box.
    submit_btn.click(fn=generate_text, inputs=input_text, outputs=output_text)
    # Clear empties the input box only.
    clear_btn.click(fn=lambda: "", outputs=input_text)

# Launch the demo
if __name__ == "__main__":
    demo.launch()