nixaut-codelabs committed on
Commit 8cc880c · verified · 1 Parent(s): 919fada

Update app.py

Files changed (1)
  1. app.py +226 -62
app.py CHANGED
@@ -90,9 +90,10 @@ model.eval()
 
 detoxify_model = Detoxify('multilingual')
 
-# Use a Hugging Face pipeline for NSFW image detection
+# Use a more accurate NSFW image detection model
 print("Loading NSFW image classification model...")
-image_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
+# Use a smaller, faster model for NSFW detection
+nsfw_classifier = pipeline("image-classification", model="cafeai/nsfw-detector")
 print("NSFW image classification model loaded.")
 
 MODERATION_SYSTEM_PROMPT = (
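Note: the image-classification pipeline used above returns a list of {label, score} dicts per image, which is what the rewritten classify_image() below iterates over. A minimal standalone sketch of that contract (the exact label names, e.g. 'nsfw' vs. 'normal', depend on the checkpoint and are an assumption here; 'example.jpg' is a placeholder file):

from PIL import Image
from transformers import pipeline

classifier = pipeline("image-classification", model="cafeai/nsfw-detector")
img = Image.open("example.jpg").convert("RGB")   # any RGB image
results = classifier(img)                        # e.g. [{'label': 'nsfw', 'score': 0.93}, ...]
nsfw_score = next((r["score"] for r in results if r["label"].lower() == "nsfw"), 0.0)
print(nsfw_score)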
@@ -174,7 +175,7 @@ class ImageContent(BaseModel):
 
 class ModerationRequest(BaseModel):
     input: Union[str, List[Union[str, TextContent, ImageContent]]] = Field(..., description="Content to moderate")
-    model: Optional[str] = Field("multimodal-moderator", description="Model to use for moderation")
+    model: Optional[str] = Field("gemma", description="Model to use for text moderation (gemma, detoxify, both)")
 
 class ModerationResponse(BaseModel):
     id: str
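Note: with this change the model field selects the text-moderation backend ("gemma", "detoxify", or "both") instead of naming one fixed multimodal model. A sketch of request bodies that fit the new schema (plain-string and structured text items only; image items follow the ImageContent model defined earlier in app.py and are omitted here):

# Plain string input, Detoxify backend
payload_text = {"input": "some text to check", "model": "detoxify"}

# Structured text item, both backends combined
payload_items = {"input": [{"type": "text", "text": "some text to check"}], "model": "both"}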
@@ -305,22 +306,27 @@ def classify_text_with_detoxify(text):
 def classify_image(image_data):
     try:
         img = Image.open(io.BytesIO(image_data)).convert("RGB")
-        results = image_classifier(img)
+        # Resize for faster processing
+        img = img.resize((224, 224))
 
-        # Extract the top result
-        top_result = results[0]
-        label = top_result['label']
-        score = top_result['score']
+        # Use the NSFW detector
+        results = nsfw_classifier(img)
 
-        # Map the label: 'normal' -> 's', 'nsfw' -> 'u'
-        classification = 'u' if label == 'nsfw' else 's'
-        nsfw_score = score if label == 'nsfw' else 1.0 - score
+        # Extract the NSFW score
+        nsfw_score = 0.0
+        for result in results:
+            if result['label'].lower() == 'nsfw':
+                nsfw_score = result['score']
+                break
+
+        # Use a threshold of 0.7 for stricter detection
+        classification = 'u' if nsfw_score > 0.7 else 's'
 
         return {
             "classification": classification,
             "label": "NSFW" if classification == 'u' else "SFW",
             "description": "Content may contain inappropriate or harmful material." if classification == 'u' else "Content appears to be safe and appropriate.",
-            "confidence": score,
+            "confidence": nsfw_score,
             "nsfw_score": nsfw_score
         }
     except Exception as e:
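Note: classify_image() now reduces the classifier output to a single nsfw_score and applies a fixed 0.7 cutoff, so a borderline score such as 0.65 is reported as SFW while the raw score is still returned in both confidence and nsfw_score. A hedged usage sketch, assuming classify_image is imported from this module and "example.jpg" exists locally:

with open("example.jpg", "rb") as f:
    result = classify_image(f.read())

print(result["classification"])                     # 'u' only when nsfw_score > 0.7
print(result["confidence"], result["nsfw_score"])   # both now carry the NSFW probability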
@@ -332,50 +338,85 @@ def classify_image(image_data):
             "nsfw_score": 0.0
         }
 
-def process_content_item(item):
+def process_content_item(item, text_model="gemma"):
     if isinstance(item, str):
-        gemma_result = classify_text_with_gemma(item)
-        detoxify_result = classify_text_with_detoxify(item)
+        if text_model == "gemma":
+            gemma_result = classify_text_with_gemma(item)
+            flagged = gemma_result["classification"] == "u"
+
+            return {
+                "flagged": flagged,
+                "categories": {
+                    "hate": flagged,
+                    "hate/threatening": flagged,
+                    "harassment": flagged,
+                    "harassment/threatening": flagged,
+                    "self-harm": flagged,
+                    "self-harm/intent": flagged,
+                    "self-harm/instructions": flagged,
+                    "sexual": flagged,
+                    "sexual/minors": flagged,
+                    "violence": flagged,
+                    "violence/graphic": flagged,
+                    "nsfw": False
+                },
+                "category_scores": {
+                    "hate": 0.9 if flagged else 0.1,
+                    "hate/threatening": 0.9 if flagged else 0.1,
+                    "harassment": 0.9 if flagged else 0.1,
+                    "harassment/threatening": 0.9 if flagged else 0.1,
+                    "self-harm": 0.9 if flagged else 0.1,
+                    "self-harm/intent": 0.9 if flagged else 0.1,
+                    "self-harm/instructions": 0.9 if flagged else 0.1,
+                    "sexual": 0.9 if flagged else 0.1,
+                    "sexual/minors": 0.9 if flagged else 0.1,
+                    "violence": 0.9 if flagged else 0.1,
+                    "violence/graphic": 0.9 if flagged else 0.1,
+                    "nsfw": 0.1
+                },
+                "text": item
+            }
 
-        flagged = gemma_result["classification"] == "u" or detoxify_result["flagged"]
+        elif text_model == "detoxify":
+            detoxify_result = classify_text_with_detoxify(item)
+            flagged = detoxify_result["flagged"]
+
+            return {
+                "flagged": flagged,
+                "categories": {
+                    "hate": detoxify_result["categories"].get("toxicity", False),
+                    "hate/threatening": detoxify_result["categories"].get("threat", False),
+                    "harassment": detoxify_result["categories"].get("insult", False),
+                    "harassment/threatening": detoxify_result["categories"].get("threat", False),
+                    "self-harm": False,
+                    "self-harm/intent": False,
+                    "self-harm/instructions": False,
+                    "sexual": detoxify_result["categories"].get("sexual_explicit", False),
+                    "sexual/minors": detoxify_result["categories"].get("sexual_explicit", False),
+                    "violence": detoxify_result["categories"].get("threat", False),
+                    "violence/graphic": detoxify_result["categories"].get("threat", False),
+                    "nsfw": detoxify_result["categories"].get("sexual_explicit", False)
+                },
+                "category_scores": {
+                    "hate": detoxify_result["category_scores"].get("toxicity", 0.1),
+                    "hate/threatening": detoxify_result["category_scores"].get("threat", 0.1),
+                    "harassment": detoxify_result["category_scores"].get("insult", 0.1),
+                    "harassment/threatening": detoxify_result["category_scores"].get("threat", 0.1),
+                    "self-harm": 0.1,
+                    "self-harm/intent": 0.1,
+                    "self-harm/instructions": 0.1,
+                    "sexual": detoxify_result["category_scores"].get("sexual_explicit", 0.1),
+                    "sexual/minors": detoxify_result["category_scores"].get("sexual_explicit", 0.1) * 0.9,
+                    "violence": detoxify_result["category_scores"].get("threat", 0.1),
+                    "violence/graphic": detoxify_result["category_scores"].get("threat", 0.1),
+                    "nsfw": detoxify_result["category_scores"].get("sexual_explicit", 0.1)
+                },
+                "text": item
+            }
 
-        return {
-            "flagged": flagged,
-            "categories": {
-                "hate": flagged,
-                "hate/threatening": flagged,
-                "harassment": flagged,
-                "harassment/threatening": flagged,
-                "self-harm": flagged,
-                "self-harm/intent": flagged,
-                "self-harm/instructions": flagged,
-                "sexual": flagged,
-                "sexual/minors": flagged,
-                "violence": flagged,
-                "violence/graphic": flagged,
-                "nsfw": detoxify_result["categories"].get("sexual_explicit", False)
-            },
-            "category_scores": {
-                "hate": 0.9 if flagged else 0.1,
-                "hate/threatening": 0.9 if flagged else 0.1,
-                "harassment": 0.9 if flagged else 0.1,
-                "harassment/threatening": 0.9 if flagged else 0.1,
-                "self-harm": 0.9 if flagged else 0.1,
-                "self-harm/intent": 0.9 if flagged else 0.1,
-                "self-harm/instructions": 0.9 if flagged else 0.1,
-                "sexual": detoxify_result["category_scores"].get("sexual_explicit", 0.1),
-                "sexual/minors": detoxify_result["category_scores"].get("sexual_explicit", 0.1) * 0.9,
-                "violence": 0.9 if flagged else 0.1,
-                "violence/graphic": 0.9 if flagged else 0.1,
-                "nsfw": detoxify_result["category_scores"].get("sexual_explicit", 0.1)
-            },
-            "text": item
-        }
-
-    elif isinstance(item, dict):
-        if item.get("type") == "text":
-            gemma_result = classify_text_with_gemma(item.get("text", ""))
-            detoxify_result = classify_text_with_detoxify(item.get("text", ""))
+        elif text_model == "both":
+            gemma_result = classify_text_with_gemma(item)
+            detoxify_result = classify_text_with_detoxify(item)
 
             flagged = gemma_result["classification"] == "u" or detoxify_result["flagged"]
 
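Note: for plain strings, process_content_item() now runs a single backend per request instead of always combining both. A short sketch of the dispatch, assuming the classify_text_with_gemma and classify_text_with_detoxify helpers defined earlier in app.py are available:

r = process_content_item("example text", text_model="gemma")
print(r["flagged"], r["categories"]["nsfw"])   # nsfw is always False on the Gemma-only path

r = process_content_item("example text", text_model="detoxify")
print(r["category_scores"]["harassment"])      # backed by Detoxify's "insult" score

r = process_content_item("example text", text_model="both")
print(r["flagged"])                            # True if either backend objects, as before this commit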
@@ -409,8 +450,125 @@ def process_content_item(item):
                 "violence/graphic": 0.9 if flagged else 0.1,
                 "nsfw": detoxify_result["category_scores"].get("sexual_explicit", 0.1)
             },
-            "text": item.get("text", "")
+            "text": item
         }
+
+    elif isinstance(item, dict):
+        if item.get("type") == "text":
+            text = item.get("text", "")
+
+            if text_model == "gemma":
+                gemma_result = classify_text_with_gemma(text)
+                flagged = gemma_result["classification"] == "u"
+
+                return {
+                    "flagged": flagged,
+                    "categories": {
+                        "hate": flagged,
+                        "hate/threatening": flagged,
+                        "harassment": flagged,
+                        "harassment/threatening": flagged,
+                        "self-harm": flagged,
+                        "self-harm/intent": flagged,
+                        "self-harm/instructions": flagged,
+                        "sexual": flagged,
+                        "sexual/minors": flagged,
+                        "violence": flagged,
+                        "violence/graphic": flagged,
+                        "nsfw": False
+                    },
+                    "category_scores": {
+                        "hate": 0.9 if flagged else 0.1,
+                        "hate/threatening": 0.9 if flagged else 0.1,
+                        "harassment": 0.9 if flagged else 0.1,
+                        "harassment/threatening": 0.9 if flagged else 0.1,
+                        "self-harm": 0.9 if flagged else 0.1,
+                        "self-harm/intent": 0.9 if flagged else 0.1,
+                        "self-harm/instructions": 0.9 if flagged else 0.1,
+                        "sexual": 0.9 if flagged else 0.1,
+                        "sexual/minors": 0.9 if flagged else 0.1,
+                        "violence": 0.9 if flagged else 0.1,
+                        "violence/graphic": 0.9 if flagged else 0.1,
+                        "nsfw": 0.1
+                    },
+                    "text": text
+                }
+
+            elif text_model == "detoxify":
+                detoxify_result = classify_text_with_detoxify(text)
+                flagged = detoxify_result["flagged"]
+
+                return {
+                    "flagged": flagged,
+                    "categories": {
+                        "hate": detoxify_result["categories"].get("toxicity", False),
+                        "hate/threatening": detoxify_result["categories"].get("threat", False),
+                        "harassment": detoxify_result["categories"].get("insult", False),
+                        "harassment/threatening": detoxify_result["categories"].get("threat", False),
+                        "self-harm": False,
+                        "self-harm/intent": False,
+                        "self-harm/instructions": False,
+                        "sexual": detoxify_result["categories"].get("sexual_explicit", False),
+                        "sexual/minors": detoxify_result["categories"].get("sexual_explicit", False),
+                        "violence": detoxify_result["categories"].get("threat", False),
+                        "violence/graphic": detoxify_result["categories"].get("threat", False),
+                        "nsfw": detoxify_result["categories"].get("sexual_explicit", False)
+                    },
+                    "category_scores": {
+                        "hate": detoxify_result["category_scores"].get("toxicity", 0.1),
+                        "hate/threatening": detoxify_result["category_scores"].get("threat", 0.1),
+                        "harassment": detoxify_result["category_scores"].get("insult", 0.1),
+                        "harassment/threatening": detoxify_result["category_scores"].get("threat", 0.1),
+                        "self-harm": 0.1,
+                        "self-harm/intent": 0.1,
+                        "self-harm/instructions": 0.1,
+                        "sexual": detoxify_result["category_scores"].get("sexual_explicit", 0.1),
+                        "sexual/minors": detoxify_result["category_scores"].get("sexual_explicit", 0.1) * 0.9,
+                        "violence": detoxify_result["category_scores"].get("threat", 0.1),
+                        "violence/graphic": detoxify_result["category_scores"].get("threat", 0.1),
+                        "nsfw": detoxify_result["category_scores"].get("sexual_explicit", 0.1)
+                    },
+                    "text": text
+                }
+
+            elif text_model == "both":
+                gemma_result = classify_text_with_gemma(text)
+                detoxify_result = classify_text_with_detoxify(text)
+
+                flagged = gemma_result["classification"] == "u" or detoxify_result["flagged"]
+
+                return {
+                    "flagged": flagged,
+                    "categories": {
+                        "hate": flagged,
+                        "hate/threatening": flagged,
+                        "harassment": flagged,
+                        "harassment/threatening": flagged,
+                        "self-harm": flagged,
+                        "self-harm/intent": flagged,
+                        "self-harm/instructions": flagged,
+                        "sexual": flagged,
+                        "sexual/minors": flagged,
+                        "violence": flagged,
+                        "violence/graphic": flagged,
+                        "nsfw": detoxify_result["categories"].get("sexual_explicit", False)
+                    },
+                    "category_scores": {
+                        "hate": 0.9 if flagged else 0.1,
+                        "hate/threatening": 0.9 if flagged else 0.1,
+                        "harassment": 0.9 if flagged else 0.1,
+                        "harassment/threatening": 0.9 if flagged else 0.1,
+                        "self-harm": 0.9 if flagged else 0.1,
+                        "self-harm/intent": 0.9 if flagged else 0.1,
+                        "self-harm/instructions": 0.9 if flagged else 0.1,
+                        "sexual": detoxify_result["category_scores"].get("sexual_explicit", 0.1),
+                        "sexual/minors": detoxify_result["category_scores"].get("sexual_explicit", 0.1) * 0.9,
+                        "violence": 0.9 if flagged else 0.1,
+                        "violence/graphic": 0.9 if flagged else 0.1,
+                        "nsfw": detoxify_result["category_scores"].get("sexual_explicit", 0.1)
+                    },
+                    "text": text
+                }
 
     elif item.get("type") == "image":
         image_data = None
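Note: the same three-way dispatch is repeated for structured {"type": "text"} items, with the extracted string echoed back under "text". A brief sketch under the same assumptions as the previous example:

item = {"type": "text", "text": "example text"}
r = process_content_item(item, text_model="both")
print(r["text"])     # "example text", the extracted string rather than the whole dict
print(r["flagged"])  # combined Gemma/Detoxify verdict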
@@ -597,6 +755,7 @@ async def moderate_content(
 
     try:
         input_data = request.input
+        text_model = request.model or "gemma"
 
         if isinstance(input_data, str):
             items = [input_data]
@@ -616,14 +775,14 @@ async def moderate_content(
 
         results = []
         for item in items:
-            result = process_content_item(item)
+            result = process_content_item(item, text_model)
             results.append(result)
 
         response_data = {
             "id": f"modr_{uuid.uuid4().hex[:24]}",
             "object": "moderation",
             "created": int(time.time()),
-            "model": request.model,
+            "model": text_model,
             "results": results
         }
 
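Note: end to end, the selected backend is threaded into every per-item result and echoed back in the response's model field. A hedged client sketch (the base URL and port are assumptions, and any API-key header the deployment expects is omitted; the route is the /v1/moderations endpoint this app serves):

import requests

resp = requests.post(
    "http://localhost:7860/v1/moderations",   # adjust to wherever the app is hosted
    json={"input": "example text", "model": "both"},
    timeout=60,
)
data = resp.json()
print(data["model"])                  # echoes the requested backend, e.g. "both"
print(data["results"][0]["flagged"])  # per-item verdict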
 
@@ -710,9 +869,11 @@ with open("templates/index.html", "w") as f:
                     </div>
                 </div>
                 <div class="mb-4">
-                    <label class="block text-sm font-medium mb-2">Model</label>
-                    <select id="modelSelect" class="w-full px-4 py-3 rounded-lg bg-white/10 border border-white/20 focus:outline-none focus:ring-2 focus:ring-indigo-400 text-white">
-                        <option value="multimodal-moderator" selected>Multimodal Moderator</option>
+                    <label class="block text-sm font-medium mb-2">Text Model</label>
+                    <select id="textModelSelect" class="w-full px-4 py-3 rounded-lg bg-white/10 border border-white/20 focus:outline-none focus:ring-2 focus:ring-indigo-400 text-white">
+                        <option value="gemma">Gemma (Fast)</option>
+                        <option value="detoxify">Detoxify (Detailed)</option>
+                        <option value="both">Both (Most Accurate)</option>
                     </select>
                 </div>
                 <div class="mt-6">
@@ -957,6 +1118,7 @@ with open("templates/index.html", "w") as f:
         const loadingModal = document.getElementById('loadingModal');
         const mixedItemsContainer = document.getElementById('mixedItemsContainer');
         const addItemBtn = document.getElementById('addItemBtn');
+        const textModelSelect = document.getElementById('textModelSelect');
         const exampleCards = document.querySelectorAll('.example-card');
 
         textTab.addEventListener('click', () => {
@@ -1142,6 +1304,8 @@ with open("templates/index.html", "w") as f:
                 return;
             }
 
+            const textModel = textModelSelect.value;
+
             showLoading(true);
             try {
                 const response = await fetch('/v1/moderations', {
@@ -1152,7 +1316,7 @@ with open("templates/index.html", "w") as f:
                     },
                     body: JSON.stringify({
                         input: text,
-                        model: document.getElementById('modelSelect').value
+                        model: textModel
                     })
                 });
 
@@ -1218,7 +1382,7 @@ with open("templates/index.html", "w") as f:
                     },
                     body: JSON.stringify({
                         input: [imageInput],
-                        model: document.getElementById('modelSelect').value
+                        model: textModelSelect.value
                    })
                });
 
@@ -1304,7 +1468,7 @@ with open("templates/index.html", "w") as f:
                     },
                     body: JSON.stringify({
                         input: inputItems,
-                        model: document.getElementById('modelSelect').value
+                        model: textModelSelect.value
                    })
                });
 
 