Update app.py
app.py CHANGED
@@ -23,7 +23,6 @@ except ImportError:
 MODEL_PATH = os.getenv("MODEL_PATH", "speakleash/sojka2")
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 LABELS = ["self-harm", "hate", "vulgar", "sex", "crime"]
-MAX_SEQ_LENGTH = 512
 # Thresholds are now hardcoded
 THRESHOLDS = {
     "self-harm": 0.5,
@@ -61,9 +60,6 @@ def load_model_and_tokenizer(model_path: str, device: str) -> Tuple[AutoModelFor
     is_peft = os.path.exists(os.path.join(model_path, 'adapter_config.json'))
     if PeftModel and is_peft:
         logger.info("PEFT adapter detected. Loading base model and attaching adapter.")
-        # Logic to load PEFT model (kept for robustness)
-        # This part assumes adapter_config.json contains base_model_name_or_path
-        # Simplified for clarity, ensure your PEFT config is correct if you use it.
         try:
             from peft import PeftConfig
             peft_config = PeftConfig.from_pretrained(model_path)
@@ -107,7 +103,9 @@ def predict(text: str) -> Dict[str, Any]:

     with torch.no_grad():
         outputs = model(**inputs)
-
+    # Using sigmoid for multi-label classification outputs
+    probabilities = torch.sigmoid(outputs.logits)
+    predicted_values = probabilities.cpu().numpy()[0]

     clipped_values = np.clip(predicted_values, 0.0, 1.0)
     return {label: float(score) for label, score in zip(LABELS, clipped_values)}
@@ -132,7 +130,6 @@ def gradio_predict(text: str) -> Tuple[str, Dict[str, float]]:
     if not unsafe_categories:
         verdict = "✅ Komunikat jest bezpieczny."
     else:
-        # Sort by score to show the most likely category first
         highest_unsafe_category = max(unsafe_categories, key=unsafe_categories.get)
         verdict = f"⚠️ Wykryto potencjalnie szkodliwe treści w kategorii: {highest_unsafe_category.upper()}"

@@ -140,8 +137,8 @@ def gradio_predict(text: str) -> Tuple[str, Dict[str, float]]:

 # --- Gradio Interface ---

-# Custom theme inspired by the provided image
-theme = gr.themes.Default
+# Custom theme inspired by the provided image - THIS IS THE CORRECTED LINE
+theme = gr.themes.Default(
     primary_hue=gr.themes.colors.blue,
     secondary_hue=gr.themes.colors.indigo,
     neutral_hue=gr.themes.colors.slate,
@@ -168,7 +165,7 @@ with gr.Blocks(theme=theme, css=".gradio-container {max-width: 960px !important;
             <div style="display: flex; align-items: center; gap: 20px; font-size: 0.9rem;">
                 <a href="#" style="text-decoration: none; color: inherit;">O projekcie</a>
                 <a href="#" style="text-decoration: none; color: inherit;">Opis kategorii</a>
-                <button class="gr-button gr-button-primary gr-button-lg"
+                <button id="test-sojka-btn" class="gr-button gr-button-primary gr-button-lg"
                     style="background-color: var(--primary-500); color: white; padding: 8px 16px; border-radius: 8px;">
                     Testuj Sójkę
                 </button>
@@ -200,10 +197,7 @@ with gr.Blocks(theme=theme, css=".gradio-container {max-width: 960px !important;
                 label="Wprowadź tekst do analizy",
                 placeholder="Tutaj wpisz tekst..."
             )
-            submit_btn = gr.Button("
-
-            # Use a more descriptive name for the submit button that matches its function
-            submit_btn.value = "Analizuj tekst"
+            submit_btn = gr.Button("Analizuj tekst", variant="primary")

             output_verdict = gr.Label(label="Wynik analizy", value="Czekam na tekst do analizy...")
             output_scores = gr.Label(label="Szczegółowe wyniki", visible=False)
|