nixaut-codelabs commited on
Commit
c500ef1
·
verified ·
1 Parent(s): b8ddba8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -15
app.py CHANGED
@@ -15,7 +15,7 @@ from fastapi.templating import Jinja2Templates
15
  from pydantic import BaseModel, Field
16
  from dotenv import load_dotenv
17
  from huggingface_hub import snapshot_download
18
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, pipeline
19
  from detoxify import Detoxify
20
  from PIL import Image
21
  import uvicorn
@@ -90,9 +90,10 @@ model.eval()
90
 
91
  detoxify_model = Detoxify('multilingual')
92
 
93
- # Use the Falconsai NSFW image detection model
94
  print("Loading NSFW image classification model...")
95
- nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
 
96
  print("NSFW image classification model loaded.")
97
 
98
  MODERATION_SYSTEM_PROMPT = (
@@ -305,24 +306,33 @@ def classify_text_with_detoxify(text):
305
  def classify_image(image_data):
306
  try:
307
  img = Image.open(io.BytesIO(image_data)).convert("RGB")
308
- # Use the Falconsai NSFW detector
309
- results = nsfw_classifier(img)
310
 
311
- # Extract the NSFW score
312
- nsfw_score = 0.0
313
- for result in results:
314
- if result['label'].lower() == 'nsfw':
315
- nsfw_score = result['score']
316
- break
317
-
318
- # Use a threshold of 0.5 for NSFW detection
319
- classification = 'u' if nsfw_score > 0.5 else 's'
 
 
 
 
 
 
 
 
 
 
 
320
 
321
  return {
322
  "classification": classification,
323
  "label": "NSFW" if classification == 'u' else "SFW",
324
  "description": "Content may contain inappropriate or harmful material." if classification == 'u' else "Content appears to be safe and appropriate.",
325
- "confidence": nsfw_score,
326
  "nsfw_score": nsfw_score
327
  }
328
  except Exception as e:
 
15
  from pydantic import BaseModel, Field
16
  from dotenv import load_dotenv
17
  from huggingface_hub import snapshot_download
18
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, AutoModelForImageClassification, ViTImageProcessor
19
  from detoxify import Detoxify
20
  from PIL import Image
21
  import uvicorn
 
90
 
91
  detoxify_model = Detoxify('multilingual')
92
 
93
+ # Load the NSFW image detection model and processor directly
94
  print("Loading NSFW image classification model...")
95
+ nsfw_model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")
96
+ nsfw_processor = ViTImageProcessor.from_pretrained('Falconsai/nsfw_image_detection')
97
  print("NSFW image classification model loaded.")
98
 
99
  MODERATION_SYSTEM_PROMPT = (
 
306
  def classify_image(image_data):
307
  try:
308
  img = Image.open(io.BytesIO(image_data)).convert("RGB")
 
 
309
 
310
+ # Use the model and processor directly as shown in the example
311
+ with torch.no_grad():
312
+ inputs = nsfw_processor(images=img, return_tensors="pt")
313
+ outputs = nsfw_model(**inputs)
314
+ logits = outputs.logits
315
+
316
+ # Get the predicted label
317
+ predicted_label = logits.argmax(-1).item()
318
+ label = nsfw_model.config.id2label[predicted_label]
319
+
320
+ # Get the confidence score
321
+ confidence = torch.softmax(logits, dim=-1)[0][predicted_label].item()
322
+
323
+ # Convert to our classification system
324
+ if label.lower() == "nsfw":
325
+ classification = "u"
326
+ nsfw_score = confidence
327
+ else: # normal
328
+ classification = "s"
329
+ nsfw_score = 1.0 - confidence
330
 
331
  return {
332
  "classification": classification,
333
  "label": "NSFW" if classification == 'u' else "SFW",
334
  "description": "Content may contain inappropriate or harmful material." if classification == 'u' else "Content appears to be safe and appropriate.",
335
+ "confidence": confidence,
336
  "nsfw_score": nsfw_score
337
  }
338
  except Exception as e: