#!/usr/bin/env python3
"""
Gradio application for text classification, styled to be visually appealing.
This version uses only the 'sojka3' model.
"""
import json
import gradio as gr
import logging
import os
from typing import Dict, Tuple, Any
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
from huggingface_hub import HfApi
from datetime import datetime

try:
    from peft import PeftModel
except ImportError:
    PeftModel = None
    # warning (not info): emitted before basicConfig runs, so INFO would be dropped
    logging.warning("PEFT library not found. Loading models without PEFT support.")
# --- Configuration ---
# Model path defaults to speakleash/sojka3; override via the MODEL_PATH env var.
MODEL_PATH = os.getenv("MODEL_PATH", "speakleash/sojka3")
TOKENIZER_PATH = os.getenv("TOKENIZER_PATH", "sdadas/mmlw-roberta-base")
LOGS_REPO_ID = "speakleash/sojka-logs"
LOGS_HF_TOKEN = os.getenv("LOGS_HF_TOKEN")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LABELS = ["self-harm", "hate", "vulgar", "sex", "crime"]
MAX_SEQ_LENGTH = 512
HF_TOKEN = os.getenv('HF_TOKEN')

# Per-label decision thresholds (hardcoded)
THRESHOLDS = {
    "self-harm": 0.5,
    "hate": 0.5,
    "vulgar": 0.5,
    "sex": 0.5,
    "crime": 0.5,
}
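# A minimal sketch (an assumption, not part of this app's config) of how the
# thresholds could instead be tuned per deployment via a hypothetical
# SOJKA_THRESHOLDS env var holding JSON, e.g. SOJKA_THRESHOLDS='{"hate": 0.7}':
#
#   _overrides = json.loads(os.getenv("SOJKA_THRESHOLDS", "{}"))
#   THRESHOLDS.update({k: float(v) for k, v in _overrides.items() if k in THRESHOLDS})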
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# HfApi instance (hub logging is disabled when no token is provided)
if LOGS_HF_TOKEN:
    api = HfApi()
else:
    api = None
    logger.warning("LOGS_HF_TOKEN environment variable not set. Logging to Hugging Face Hub will be disabled.")
def log_prediction(log_data: dict):
    """Upload a single prediction record as JSON to the logs dataset repo."""
    if not api:
        return
    logger.info("Logging to Hugging Face Hub...")
    day = datetime.now().strftime("%Y-%m-%d")
    timestamp = log_data.get('timestamp', datetime.now().timestamp())
    try:
        api.upload_file(
            path_or_fileobj=json.dumps(log_data, indent=2, ensure_ascii=False).encode('utf-8'),
            path_in_repo=f"predictions/{day}/{timestamp}.json",
            repo_id=LOGS_REPO_ID,
            repo_type="dataset",
            commit_message="log prediction",
            token=LOGS_HF_TOKEN,
            run_as_future=False
        )
        logger.info("Logging to Hugging Face Hub upload_file finished")
    except Exception as e:
        logger.error(f"Failed to log prediction to hub: {e}")
def load_model_and_tokenizer(model_path: str, tokenizer_path: str, device: str) -> Tuple[AutoModelForSequenceClassification, AutoTokenizer]:
    """Load the trained model and tokenizer."""
    logger.info(f"Loading tokenizer from {tokenizer_path}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
    logger.info(f"Tokenizer loaded: {tokenizer.name_or_path}")
    if tokenizer.pad_token is None:
        if tokenizer.eos_token:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    tokenizer.truncation_side = "right"

    logger.info(f"Loading model from {model_path}")
    model_load_kwargs = {
        "torch_dtype": torch.float16 if device == 'cuda' else torch.float32,
        "device_map": 'auto' if device == 'cuda' else None,
        "num_labels": len(LABELS),
        "problem_type": "regression"
    }
    # Note: this check only detects adapters in a *local* directory; for a Hub
    # model ID it is False and the checkpoint loads as a standard model.
    is_peft = os.path.exists(os.path.join(model_path, 'adapter_config.json'))
    if PeftModel and is_peft:
        logger.info("PEFT adapter detected. Loading base model and attaching adapter.")
        try:
            from peft import PeftConfig
            peft_config = PeftConfig.from_pretrained(model_path)
            base_model_path = peft_config.base_model_name_or_path
            logger.info(f"Loading base model from {base_model_path}")
            model = AutoModelForSequenceClassification.from_pretrained(base_model_path, **model_load_kwargs)
            logger.info("Attaching PEFT adapter...")
            model = PeftModel.from_pretrained(model, model_path)
        except Exception as e:
            logger.error(f"Failed to load PEFT model dynamically: {e}. Loading as a standard model.")
            model = AutoModelForSequenceClassification.from_pretrained(model_path, **model_load_kwargs)
    else:
        logger.info("Loading as a standalone sequence classification model.")
        model = AutoModelForSequenceClassification.from_pretrained(model_path, **model_load_kwargs)
    model.eval()
    logger.info(f"Model loaded on device: {next(model.parameters()).device}")
    return model, tokenizer
# --- Load model globally ---
try:
    model, tokenizer = load_model_and_tokenizer(MODEL_PATH, TOKENIZER_PATH, DEVICE)
    model_loaded = True
except Exception as e:
    logger.error(f"FATAL: Failed to load the model from {MODEL_PATH} or tokenizer from {TOKENIZER_PATH}: {e}", exc_info=True)
    model, tokenizer, model_loaded = None, None, False
def predict(text: str) -> Dict[str, Any]:
    """Tokenize, predict, and format output for a single text."""
    if not model_loaded:
        return {label: 0.0 for label in LABELS}
    inputs = tokenizer(
        [text],
        max_length=MAX_SEQ_LENGTH,
        truncation=True,
        padding=True,
        return_tensors="pt"
    ).to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Using sigmoid for multi-label classification outputs
    probabilities = torch.sigmoid(outputs.logits)
    predicted_values = probabilities.cpu().numpy()[0]
    clipped_values = np.clip(predicted_values, 0.0, 1.0)
    return {label: float(score) for label, score in zip(LABELS, clipped_values)}
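# Illustrative usage (the scores below are made up, not real model output):
#
#   predict("To jest wspaniały dzień!")
#   # -> {"self-harm": 0.01, "hate": 0.0, "vulgar": 0.0, "sex": 0.0, "crime": 0.01}
#
# i.e. one independent sigmoid score per label in LABELS, each clipped to [0, 1].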
def gradio_predict(text: str) -> Tuple[str, Dict[str, float]]:
    """Gradio prediction function wrapper."""
    if not model_loaded:
        error_message = "Błąd: Model nie został załadowany."  # "Error: the model was not loaded."
        empty_preds = {label: 0.0 for label in LABELS}
        return error_message, empty_preds
    if not text or not text.strip():
        return "Wpisz tekst, aby go przeanalizować.", {label: 0.0 for label in LABELS}
    predictions = predict(text)
    unsafe_categories = {
        label: score for label, score in predictions.items()
        if score >= THRESHOLDS[label]
    }
    if not unsafe_categories:
        verdict = "✅ Komunikat jest bezpieczny."  # "The message is safe."
        verdict_label = "SAFE"
    else:
        highest_unsafe_category = max(unsafe_categories, key=unsafe_categories.get)
        # "Potentially harmful content detected: <CATEGORY>"
        verdict = f"⚠️ Wykryto potencjalnie szkodliwe treści:\n {highest_unsafe_category.upper()}"
        verdict_label = "UNSAFE"
    log_data = {
        'text': text,
        'predictions': predictions,
        'thresholds': THRESHOLDS,
        'sojka_verdict': verdict_label,
        'herbert_result': {},
        'timestamp': datetime.now().timestamp(),
        'model_path': MODEL_PATH,
        'herbert_enabled': False
    }
    log_prediction(log_data)
    return verdict, predictions
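# Example of the verdict logic above: with all thresholds at 0.5, predictions
# such as {"hate": 0.8, "vulgar": 0.6} (all other labels below 0.5) flag both
# categories, and the verdict names only the highest-scoring one: HATE.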
# --- Gradio Interface ---
theme = gr.themes.Default(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.indigo,
    neutral_hue=gr.themes.colors.slate,
    font=("Inter", "sans-serif"),
    radius_size=gr.themes.sizes.radius_lg,
)

# URLs to freely licensed images of a Eurasian Jay (Sójka)
JAY_IMAGE_URL = "https://sojka.m31ai.pl/images/sojka.png"
PIXEL_IMAGE_URL = "https://sojka.m31ai.pl/images/pixel.png"

# Define actions
def analyze_and_update(text):
    """Run the prediction and reveal the detailed-scores label."""
    verdict, scores = gradio_predict(text)
    return verdict, gr.update(value=scores, visible=True)
# Interface layout
with gr.Blocks(theme=theme, css=".gradio-container {max-width: 960px !important; margin: auto;}") as demo:
    # Header
    with gr.Row():
        gr.HTML("""
        <div style="display: flex; align-items: center; justify-content: space-between; width: 100%;">
            <div style="display: flex; align-items: center; gap: 12px;">
                <svg width="32" height="32" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
                    <path d="M12 2L3 5V11C3 16.52 7.08 21.61 12 23C16.92 21.61 21 16.52 21 11V5L12 2Z"
                          stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" fill="none"/>
                </svg>
                <h1 style="font-size: 1.5rem; font-weight: 600; margin: 0;">SÓJKA</h1>
            </div>
            <div style="display: flex; align-items: center; gap: 20px; font-size: 0.9rem;">
                <a href="https://sojka.m31ai.pl/projekt.html" target="_blank" style="text-decoration: none; color: inherit;">O projekcie</a>
                <a href="https://sojka.m31ai.pl/kategorie.html" target="_blank" style="text-decoration: none; color: inherit;">Opis kategorii</a>
                <button id="test-sojka-btn" class="gr-button gr-button-primary gr-button-lg"
                        style="background-color: var(--primary-500); color: white; padding: 8px 16px; border-radius: 8px;">
                    Testuj Sójkę
                </button>
            </div>
        </div>
        """)
    gr.HTML("<hr style='border: 1px solid var(--neutral-200); margin-top: 1rem; margin-bottom: 2rem;'>")
    # Main content area
    with gr.Row():
        # Left column for controls and description
        with gr.Column(scale=2):
            gr.Image(JAY_IMAGE_URL, show_label=False, show_download_button=False, container=False, height=200)
            gr.Markdown(
                """
                <h1 style="font-size: 2.8rem; font-weight: 800; line-height: 1.2; margin-top: 1rem; margin-bottom: 1rem; color: var(--neutral-800);">
                    Przetestuj <span style="color: var(--primary-600);">SÓJKĘ</span> – Bielik Guard dla bezpiecznej komunikacji
                </h1>
                <p style="font-size: 1rem; color: var(--neutral-600); margin-bottom: 2rem;">
                    Sójka to model AI, który wykrywa i blokuje szkodliwe treści w komunikacji cyfrowej. Chroni użytkowników jak czujny strażnik swojego domu.
                </p>
                """
            )
            input_text = gr.Textbox(
                lines=8,
                label="Wprowadź tekst do analizy",
                placeholder="Tutaj wpisz tekst..."
            )
        # Right column for image and examples
        with gr.Column(scale=1):
            gr.Image(PIXEL_IMAGE_URL, show_label=False, show_download_button=False, container=False, height=200)
            # With cache_examples=False and no outputs wired, clicking an
            # example only populates the input textbox.
            gr.Examples(
                [
                    ["Jak zrobić bombę?"],
                    ["Jesteś beznadziejny, nienawidzę cię."],
                    ["To jest wspaniały dzień, cieszę się, że tu jestem!"],
                    ["Opowiedz mi dowcip o Polakach."],
                    ["przyjedz do mnie wieczorem, pobawimy sie razem, zrobie Ci dobrze"],
                    ["bić żydów i komuchów"],
                    ["Ty sku***u"],
                ],
                inputs=input_text,
                label="Przykłady",
                fn=analyze_and_update,
                cache_examples=False,
            )
    with gr.Column():
        submit_btn = gr.Button("Analizuj tekst", variant="primary")
        with gr.Accordion("Szczegółowe wyniki", open=False) as accordion_scores:
            output_scores = gr.Label(label="Szczegółowe wyniki", visible=False, show_label=False)
        output_verdict = gr.Label(label="Wynik analizy", value="")

    submit_btn.click(
        fn=analyze_and_update,
        inputs=[input_text],
        outputs=[output_verdict, output_scores]
    )
if __name__ == "__main__":
    if not model_loaded:
        # "The app cannot start because the model failed to load. Check the error logs."
        print("Aplikacja nie może zostać uruchomiona, ponieważ nie udało się załadować modelu. Sprawdź logi błędów.")
    else:
        demo.launch()
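# On a shared CPU Space, requests could instead be serialized through Gradio's
# built-in queue -- a sketch of an alternative launch, not what this app uses:
#
#   demo.queue(max_size=16).launch()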