nixaut-codelabs committed
Commit 5261c99 · verified · Parent: c500ef1

Update app.py

Files changed (1)
app.py  +76 -78
app.py CHANGED
@@ -94,6 +94,7 @@ detoxify_model = Detoxify('multilingual')
 print("Loading NSFW image classification model...")
 nsfw_model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")
 nsfw_processor = ViTImageProcessor.from_pretrained('Falconsai/nsfw_image_detection')
+nsfw_model.eval()  # Set to evaluation mode
 print("NSFW image classification model loaded.")
 
 MODERATION_SYSTEM_PROMPT = (
@@ -305,11 +306,15 @@ def classify_text_with_detoxify(text):
 
 def classify_image(image_data):
     try:
+        # Open and convert the image
         img = Image.open(io.BytesIO(image_data)).convert("RGB")
 
-        # Use the model and processor directly as shown in the example
+        # Process the image with the NSFW model
         with torch.no_grad():
             inputs = nsfw_processor(images=img, return_tensors="pt")
+            # Move to the same device as the model
+            inputs = {k: v.to(nsfw_model.device) for k, v in inputs.items()}
+
             outputs = nsfw_model(**inputs)
             logits = outputs.logits
 
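The hunk above is the core inference path in classify_image: the processor builds the pixel tensors, they are moved to the model's device, and the forward pass runs under torch.no_grad() with the model in eval mode (set in the first hunk). A minimal standalone sketch of the same pattern follows; the local file name and the softmax post-processing are illustrative assumptions, since this hunk stops at the logits.

    import torch
    from PIL import Image
    from transformers import AutoModelForImageClassification, ViTImageProcessor

    model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")
    processor = ViTImageProcessor.from_pretrained("Falconsai/nsfw_image_detection")
    model.eval()  # inference mode, as in the first hunk

    img = Image.open("example.jpg").convert("RGB")  # hypothetical local file
    with torch.no_grad():
        inputs = processor(images=img, return_tensors="pt")
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        logits = model(**inputs).logits

    # Assumed post-processing: softmax over the class labels, then take the top label.
    probs = logits.softmax(dim=-1)[0]
    print(model.config.id2label[int(probs.argmax())], float(probs.max()))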
@@ -336,6 +341,7 @@ def classify_image(image_data):
             "nsfw_score": nsfw_score
         }
     except Exception as e:
+        print(f"Error in classify_image: {str(e)}")
         return {
             "classification": "s",
             "label": "ERROR",
@@ -345,6 +351,7 @@ def classify_image(image_data):
         }
 
 def process_content_item(item, text_model="gemma"):
+    # Handle string input (simple text)
     if isinstance(item, str):
         if text_model == "gemma":
             gemma_result = classify_text_with_gemma(item)
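process_content_item, updated across the hunks below, returns the same OpenAI-style fallback payload in several places: twelve category flags set to False plus category_scores pinned at 0.1. app.py spells these dictionaries out inline; a hypothetical helper that builds the same shape is sketched here for reference.

    CATEGORIES = [
        "hate", "hate/threatening", "harassment", "harassment/threatening",
        "self-harm", "self-harm/intent", "self-harm/instructions",
        "sexual", "sexual/minors", "violence", "violence/graphic", "nsfw",
    ]

    def default_moderation_result(**extra):
        # Same structure as the inline fallbacks; **extra carries optional
        # fields such as image_url or image_base64.
        return {
            "flagged": False,
            "categories": {c: False for c in CATEGORIES},
            "category_scores": {c: 0.1 for c in CATEGORIES},
            **extra,
        }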
@@ -459,8 +466,11 @@ def process_content_item(item, text_model="gemma"):
             "text": item
         }
 
+    # Handle dictionary input (structured content)
     elif isinstance(item, dict):
-        if item.get("type") == "text":
+        content_type = item.get("type")
+
+        if content_type == "text":
             text = item.get("text", "")
 
             if text_model == "gemma":
@@ -576,89 +586,35 @@ def process_content_item(item, text_model="gemma"):
                 "text": text
             }
 
-        elif item.get("type") == "image":
+        elif content_type == "image":
             image_data = None
+            image_url = item.get("url")
+            image_base64 = item.get("base64")
 
-            if item.get("url"):
+            # Get image data from URL
+            if image_url:
                 try:
-                    response = requests.get(item.get("url"))
-                    image_data = response.content
-                except Exception:
-                    return {
-                        "flagged": False,
-                        "categories": {
-                            "hate": False,
-                            "hate/threatening": False,
-                            "harassment": False,
-                            "harassment/threatening": False,
-                            "self-harm": False,
-                            "self-harm/intent": False,
-                            "self-harm/instructions": False,
-                            "sexual": False,
-                            "sexual/minors": False,
-                            "violence": False,
-                            "violence/graphic": False,
-                            "nsfw": False
-                        },
-                        "category_scores": {
-                            "hate": 0.1,
-                            "hate/threatening": 0.1,
-                            "harassment": 0.1,
-                            "harassment/threatening": 0.1,
-                            "self-harm": 0.1,
-                            "self-harm/intent": 0.1,
-                            "self-harm/instructions": 0.1,
-                            "sexual": 0.1,
-                            "sexual/minors": 0.1,
-                            "violence": 0.1,
-                            "violence/graphic": 0.1,
-                            "nsfw": 0.1
-                        },
-                        "image_url": item.get("url")
-                    }
+                    response = requests.get(image_url)
+                    if response.status_code == 200:
+                        image_data = response.content
+                    else:
+                        print(f"Failed to fetch image from URL: {image_url}, status code: {response.status_code}")
+                except Exception as e:
+                    print(f"Error fetching image from URL: {str(e)}")
 
-            elif item.get("base64"):
+            # Get image data from base64
+            elif image_base64:
                 try:
-                    if item.get("base64").startswith("data:image"):
-                        base64_data = item.get("base64").split(",")[1]
+                    if image_base64.startswith("data:image"):
+                        base64_data = image_base64.split(",")[1]
                     else:
-                        base64_data = item.get("base64")
+                        base64_data = image_base64
 
                     image_data = base64.b64decode(base64_data)
-                except Exception:
-                    return {
-                        "flagged": False,
-                        "categories": {
-                            "hate": False,
-                            "hate/threatening": False,
-                            "harassment": False,
-                            "harassment/threatening": False,
-                            "self-harm": False,
-                            "self-harm/intent": False,
-                            "self-harm/instructions": False,
-                            "sexual": False,
-                            "sexual/minors": False,
-                            "violence": False,
-                            "violence/graphic": False,
-                            "nsfw": False
-                        },
-                        "category_scores": {
-                            "hate": 0.1,
-                            "hate/threatening": 0.1,
-                            "harassment": 0.1,
-                            "harassment/threatening": 0.1,
-                            "self-harm": 0.1,
-                            "self-harm/intent": 0.1,
-                            "self-harm/instructions": 0.1,
-                            "sexual": 0.1,
-                            "sexual/minors": 0.1,
-                            "violence": 0.1,
-                            "violence/graphic": 0.1,
-                            "nsfw": 0.1
-                        },
-                        "image_base64": item.get("base64")[:50] + "..." if len(item.get("base64", "")) > 50 else item.get("base64", "")
-                    }
+                except Exception as e:
+                    print(f"Error decoding base64 image: {str(e)}")
 
+            # Process the image if we have data
             if image_data:
                 image_result = classify_image(image_data)
                 flagged = image_result["classification"] == "u"
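The image branch above accepts either an http(s) URL or a base64 string that may carry a data: URI prefix, and reduces both to raw bytes before calling classify_image. A condensed sketch of that normalization, with a hypothetical helper name and made-up inputs:

    import base64
    import requests

    def load_image_bytes(url=None, b64=None):
        # Mirrors the URL / base64 branches above.
        if url:
            resp = requests.get(url, timeout=10)  # timeout added for safety; not in app.py
            return resp.content if resp.status_code == 200 else None
        if b64:
            if b64.startswith("data:image"):  # e.g. "data:image/png;base64,...."
                b64 = b64.split(",")[1]
            return base64.b64decode(b64)
        return None

    # Example: load_image_bytes(url="https://example.com/picture.jpg")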
@@ -693,10 +649,46 @@ def process_content_item(item, text_model="gemma"):
                        "violence/graphic": 0.1,
                        "nsfw": image_result["nsfw_score"]
                    },
-                   "image_url": item.get("url"),
-                   "image_base64": item.get("base64")[:50] + "..." if item.get("base64") and len(item.get("base64", "")) > 50 else item.get("base64", "")
+                   "image_url": image_url,
+                   "image_base64": image_base64[:50] + "..." if image_base64 and len(image_base64) > 50 else image_base64
+               }
+           else:
+               # Return error if no image data
+               return {
+                   "flagged": False,
+                   "categories": {
+                       "hate": False,
+                       "hate/threatening": False,
+                       "harassment": False,
+                       "harassment/threatening": False,
+                       "self-harm": False,
+                       "self-harm/intent": False,
+                       "self-harm/instructions": False,
+                       "sexual": False,
+                       "sexual/minors": False,
+                       "violence": False,
+                       "violence/graphic": False,
+                       "nsfw": False
+                   },
+                   "category_scores": {
+                       "hate": 0.1,
+                       "hate/threatening": 0.1,
+                       "harassment": 0.1,
+                       "harassment/threatening": 0.1,
+                       "self-harm": 0.1,
+                       "self-harm/intent": 0.1,
+                       "self-harm/instructions": 0.1,
+                       "sexual": 0.1,
+                       "sexual/minors": 0.1,
+                       "violence": 0.1,
+                       "violence/graphic": 0.1,
+                       "nsfw": 0.1
+                   },
+                   "image_url": image_url,
+                   "image_base64": image_base64[:50] + "..." if image_base64 and len(image_base64) > 50 else image_base64
                }
 
+    # Default return for invalid items
     return {
         "flagged": False,
         "categories": {
@@ -763,10 +755,15 @@ async def moderate_content(
     input_data = request.input
     text_model = request.model or "gemma"
 
+    # Normalize input to a list of items
+    items = []
+
     if isinstance(input_data, str):
+        # Single string input
         items = [input_data]
         total_tokens += count_tokens(input_data)
     elif isinstance(input_data, list):
+        # List of items
         items = input_data
         for item in items:
             if isinstance(item, str):
@@ -779,6 +776,7 @@ async def moderate_content(
     if len(items) > 10:
         raise HTTPException(status_code=400, detail="Too many input items. Maximum 10 allowed.")
 
+    # Process each item individually
     results = []
     for item in items:
         result = process_content_item(item, text_model)
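For reference, moderate_content reads request.input (a single string, or a list of up to ten strings and typed dicts) plus an optional request.model. A hedged client-side example follows; the route path and port are assumptions, since the FastAPI route decorator sits outside this diff:

    import requests

    payload = {
        "model": "gemma",  # optional; "gemma" is the default
        "input": [
            "some text to moderate",
            {"type": "text", "text": "another snippet"},
            {"type": "image", "url": "https://example.com/picture.jpg"},
        ],  # more than 10 items triggers HTTP 400
    }
    # "/v1/moderations" and port 7860 are placeholders; check the decorator in app.py.
    resp = requests.post("http://localhost:7860/v1/moderations", json=payload)
    print(resp.json())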
 