Spaces:
Paused
Paused
add prompt safety classifier
Browse files- app.py +57 -4
- requirements.txt +2 -0
app.py
CHANGED
|
@@ -9,8 +9,11 @@ import os
|
|
| 9 |
from dotenv import load_dotenv
|
| 10 |
import json
|
| 11 |
from PIL import Image, ImageDraw, ImageFont
|
|
|
|
|
|
|
| 12 |
import uuid
|
| 13 |
import threading
|
|
|
|
| 14 |
|
| 15 |
# Load environment variables first
|
| 16 |
load_dotenv()
|
|
@@ -41,6 +44,38 @@ BACKENDS = {
|
|
| 41 |
},
|
| 42 |
}
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
class BackendStatus:
|
| 46 |
|
|
@@ -600,6 +635,24 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
| 600 |
)
|
| 601 |
return
|
| 602 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
# Status message
|
| 604 |
status_message = f"๐ PROCESSING: '{prompt}'"
|
| 605 |
|
|
@@ -749,10 +802,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
| 749 |
|
| 750 |
# Launch with increased max_threads
|
| 751 |
if __name__ == "__main__":
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
demo.queue(max_size=4).launch(
|
| 757 |
server_name="0.0.0.0",
|
| 758 |
max_threads=16, # Increase thread count for better concurrency
|
|
|
|
| 9 |
from dotenv import load_dotenv
|
| 10 |
import json
|
| 11 |
from PIL import Image, ImageDraw, ImageFont
|
| 12 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 13 |
+
import torch
|
| 14 |
import uuid
|
| 15 |
import threading
|
| 16 |
+
import functools
|
| 17 |
|
| 18 |
# Load environment variables first
|
| 19 |
load_dotenv()
|
|
|
|
| 44 |
},
|
| 45 |
}
|
| 46 |
|
| 47 |
+
MODEL_URL = "MichalMlodawski/nsfw-text-detection-large"
|
| 48 |
+
TITLE = "๐ผ๏ธ๐ Image Prompt Safety Classifier ๐ก๏ธ"
|
| 49 |
+
DESCRIPTION = "โจ Enter an image generation prompt to classify its safety level! โจ"
|
| 50 |
+
|
| 51 |
+
# Load model and tokenizer
|
| 52 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
|
| 53 |
+
model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
|
| 54 |
+
|
| 55 |
+
# Define class names with emojis and detailed descriptions
|
| 56 |
+
CLASS_NAMES = {
|
| 57 |
+
0: "โ
SAFE - This prompt is appropriate and harmless.",
|
| 58 |
+
1: "โ ๏ธ QUESTIONABLE - This prompt may require further review.",
|
| 59 |
+
2: "๐ซ UNSAFE - This prompt is likely to generate inappropriate content."
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@functools.lru_cache(maxsize=128)
|
| 64 |
+
def classify_text(text):
|
| 65 |
+
inputs = tokenizer(text,
|
| 66 |
+
return_tensors="pt",
|
| 67 |
+
truncation=True,
|
| 68 |
+
padding=True,
|
| 69 |
+
max_length=1024)
|
| 70 |
+
|
| 71 |
+
with torch.no_grad():
|
| 72 |
+
outputs = model(**inputs)
|
| 73 |
+
|
| 74 |
+
logits = outputs.logits
|
| 75 |
+
predicted_class = torch.argmax(logits, dim=1).item()
|
| 76 |
+
|
| 77 |
+
return predicted_class, CLASS_NAMES[predicted_class]
|
| 78 |
+
|
| 79 |
|
| 80 |
class BackendStatus:
|
| 81 |
|
|
|
|
| 635 |
)
|
| 636 |
return
|
| 637 |
|
| 638 |
+
# Check if the prompt is safe
|
| 639 |
+
classification, message = classify_text(prompt)
|
| 640 |
+
if classification != 0:
|
| 641 |
+
# Handle unsafe prompt case
|
| 642 |
+
yield (
|
| 643 |
+
message,
|
| 644 |
+
message,
|
| 645 |
+
gr.update(visible=True),
|
| 646 |
+
gr.update(visible=False),
|
| 647 |
+
None,
|
| 648 |
+
None,
|
| 649 |
+
None,
|
| 650 |
+
None,
|
| 651 |
+
session_id, # Return the session ID
|
| 652 |
+
None,
|
| 653 |
+
)
|
| 654 |
+
return
|
| 655 |
+
|
| 656 |
# Status message
|
| 657 |
status_message = f"๐ PROCESSING: '{prompt}'"
|
| 658 |
|
|
|
|
| 802 |
|
| 803 |
# Launch with increased max_threads
|
| 804 |
if __name__ == "__main__":
|
| 805 |
+
demo.queue(max_size=50).launch(
|
| 806 |
+
server_name="0.0.0.0",
|
| 807 |
+
max_threads=16, # Increase thread count for better concurrency
|
| 808 |
+
)
|
| 809 |
demo.queue(max_size=4).launch(
|
| 810 |
server_name="0.0.0.0",
|
| 811 |
max_threads=16, # Increase thread count for better concurrency
|
requirements.txt
CHANGED
|
@@ -3,3 +3,5 @@ aiohttp
|
|
| 3 |
plotly
|
| 4 |
python-dotenv
|
| 5 |
pydantic==2.8.2
|
|
|
|
|
|
|
|
|
| 3 |
plotly
|
| 4 |
python-dotenv
|
| 5 |
pydantic==2.8.2
|
| 6 |
+
torch
|
| 7 |
+
transformers==4.37.2
|