Update app.py
app.py
CHANGED
@@ -15,7 +15,7 @@ from fastapi.templating import Jinja2Templates
 from pydantic import BaseModel, Field
 from dotenv import load_dotenv
 from huggingface_hub import snapshot_download
-from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer,
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, AutoModelForImageClassification, ViTImageProcessor
 from detoxify import Detoxify
 from PIL import Image
 import uvicorn
@@ -90,9 +90,10 @@ model.eval()

 detoxify_model = Detoxify('multilingual')

-#
+# Load the NSFW image detection model and processor directly
 print("Loading NSFW image classification model...")
-
+nsfw_model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")
+nsfw_processor = ViTImageProcessor.from_pretrained('Falconsai/nsfw_image_detection')
 print("NSFW image classification model loaded.")

 MODERATION_SYSTEM_PROMPT = (
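The branch added to `classify_image` in the next hunk keys off the string label returned by this model (`label.lower() == "nsfw"`), so it can help to check the checkpoint's label mapping once after loading. A minimal sketch, assuming the same `Falconsai/nsfw_image_detection` checkpoint as above; the exact label strings are a property of that checkpoint, so the values in the comment are an assumption to verify rather than a guarantee:

```python
from transformers import AutoModelForImageClassification

# Load the same checkpoint the app loads and inspect its label mapping.
model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")

# classify_image below assumes one of these labels lower-cases to "nsfw",
# e.g. a mapping along the lines of {0: "normal", 1: "nsfw"}.
print(model.config.id2label)
```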
@@ -305,24 +306,33 @@ def classify_text_with_detoxify(text):
 def classify_image(image_data):
     try:
         img = Image.open(io.BytesIO(image_data)).convert("RGB")
-        # Use the Falconsai NSFW detector
-        results = nsfw_classifier(img)

-        #
-
-
-
-
-
-
-
-
+        # Use the model and processor directly as shown in the example
+        with torch.no_grad():
+            inputs = nsfw_processor(images=img, return_tensors="pt")
+            outputs = nsfw_model(**inputs)
+            logits = outputs.logits
+
+        # Get the predicted label
+        predicted_label = logits.argmax(-1).item()
+        label = nsfw_model.config.id2label[predicted_label]
+
+        # Get the confidence score
+        confidence = torch.softmax(logits, dim=-1)[0][predicted_label].item()
+
+        # Convert to our classification system
+        if label.lower() == "nsfw":
+            classification = "u"
+            nsfw_score = confidence
+        else:  # normal
+            classification = "s"
+            nsfw_score = 1.0 - confidence

         return {
             "classification": classification,
             "label": "NSFW" if classification == 'u' else "SFW",
             "description": "Content may contain inappropriate or harmful material." if classification == 'u' else "Content appears to be safe and appropriate.",
-            "confidence":
+            "confidence": confidence,
             "nsfw_score": nsfw_score
         }
     except Exception as e:
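To exercise the new inference path outside the FastAPI app, here is a standalone sketch that mirrors the logic in `classify_image`, assuming the checkpoint names from the diff; `classify_image_bytes` and the `test.jpg` path are illustrative helpers, not part of the Space's code:

```python
import io

import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

# Same checkpoint as in the diff above.
nsfw_model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")
nsfw_processor = ViTImageProcessor.from_pretrained("Falconsai/nsfw_image_detection")
nsfw_model.eval()


def classify_image_bytes(image_data: bytes) -> dict:
    """Mirror of the classify_image logic, without the FastAPI wiring."""
    img = Image.open(io.BytesIO(image_data)).convert("RGB")
    with torch.no_grad():
        inputs = nsfw_processor(images=img, return_tensors="pt")
        logits = nsfw_model(**inputs).logits
    predicted_label = logits.argmax(-1).item()
    label = nsfw_model.config.id2label[predicted_label]
    confidence = torch.softmax(logits, dim=-1)[0][predicted_label].item()
    nsfw_score = confidence if label.lower() == "nsfw" else 1.0 - confidence
    return {
        "classification": "u" if label.lower() == "nsfw" else "s",
        "confidence": confidence,
        "nsfw_score": nsfw_score,
    }


if __name__ == "__main__":
    with open("test.jpg", "rb") as f:  # hypothetical local test image
        print(classify_image_bytes(f.read()))
```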