nixaut-codelabs committed
Commit aa976b8 · verified · 1 Parent(s): 5261c99

Update app.py

Files changed (1):
  1. app.py +110 -373
app.py CHANGED
@@ -7,8 +7,10 @@ import io
 import uuid
 import requests
 import numpy as np
+import asyncio
 from typing import List, Dict, Any, Optional, Union
 from fastapi import FastAPI, HTTPException, Depends, Request, File, UploadFile
+from fastapi.concurrency import run_in_threadpool
 from fastapi.responses import HTMLResponse, JSONResponse
 from fastapi.staticfiles import StaticFiles
 from fastapi.templating import Jinja2Templates
@@ -90,11 +92,10 @@ model.eval()
 
 detoxify_model = Detoxify('multilingual')
 
-# Load the NSFW image detection model and processor directly
 print("Loading NSFW image classification model...")
 nsfw_model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")
 nsfw_processor = ViTImageProcessor.from_pretrained('Falconsai/nsfw_image_detection')
-nsfw_model.eval() # Set to evaluation mode
+nsfw_model.eval()
 print("NSFW image classification model loaded.")
 
 MODERATION_SYSTEM_PROMPT = (
@@ -306,30 +307,24 @@ def classify_text_with_detoxify(text):
 
 def classify_image(image_data):
     try:
-        # Open and convert the image
         img = Image.open(io.BytesIO(image_data)).convert("RGB")
 
-        # Process the image with the NSFW model
         with torch.no_grad():
             inputs = nsfw_processor(images=img, return_tensors="pt")
-            # Move to the same device as the model
             inputs = {k: v.to(nsfw_model.device) for k, v in inputs.items()}
 
             outputs = nsfw_model(**inputs)
             logits = outputs.logits
 
-        # Get the predicted label
         predicted_label = logits.argmax(-1).item()
         label = nsfw_model.config.id2label[predicted_label]
 
-        # Get the confidence score
        confidence = torch.softmax(logits, dim=-1)[0][predicted_label].item()
 
-        # Convert to our classification system
         if label.lower() == "nsfw":
             classification = "u"
             nsfw_score = confidence
-        else: # normal
+        else:
             classification = "s"
             nsfw_score = 1.0 - confidence
 
@@ -350,375 +345,124 @@ def classify_image(image_data):
             "nsfw_score": 0.0
         }
 
-def process_content_item(item, text_model="gemma"):
-    # Handle string input (simple text)
+def process_content_item(item: Union[str, Dict], text_model: str = "gemma") -> Dict:
     if isinstance(item, str):
+        item = {"type": "text", "text": item}
+
+    content_type = item.get("type")
+
+    if content_type == "text":
+        text = item.get("text", "")
         if text_model == "gemma":
-            gemma_result = classify_text_with_gemma(item)
+            gemma_result = classify_text_with_gemma(text)
             flagged = gemma_result["classification"] == "u"
-
+            scores = {
+                "hate": 0.9 if flagged else 0.1, "hate/threatening": 0.9 if flagged else 0.1,
+                "harassment": 0.9 if flagged else 0.1, "harassment/threatening": 0.9 if flagged else 0.1,
+                "self-harm": 0.9 if flagged else 0.1, "self-harm/intent": 0.9 if flagged else 0.1,
+                "self-harm/instructions": 0.9 if flagged else 0.1,
+                "sexual": 0.9 if flagged else 0.1, "sexual/minors": 0.9 if flagged else 0.1,
+                "violence": 0.9 if flagged else 0.1, "violence/graphic": 0.9 if flagged else 0.1,
+                "nsfw": 0.1,
+            }
             return {
                 "flagged": flagged,
-                "categories": {
-                    "hate": flagged,
-                    "hate/threatening": flagged,
-                    "harassment": flagged,
-                    "harassment/threatening": flagged,
-                    "self-harm": flagged,
-                    "self-harm/intent": flagged,
-                    "self-harm/instructions": flagged,
-                    "sexual": flagged,
-                    "sexual/minors": flagged,
-                    "violence": flagged,
-                    "violence/graphic": flagged,
-                    "nsfw": False
-                },
-                "category_scores": {
-                    "hate": 0.9 if flagged else 0.1,
-                    "hate/threatening": 0.9 if flagged else 0.1,
-                    "harassment": 0.9 if flagged else 0.1,
-                    "harassment/threatening": 0.9 if flagged else 0.1,
-                    "self-harm": 0.9 if flagged else 0.1,
-                    "self-harm/intent": 0.9 if flagged else 0.1,
-                    "self-harm/instructions": 0.9 if flagged else 0.1,
-                    "sexual": 0.9 if flagged else 0.1,
-                    "sexual/minors": 0.9 if flagged else 0.1,
-                    "violence": 0.9 if flagged else 0.1,
-                    "violence/graphic": 0.9 if flagged else 0.1,
-                    "nsfw": 0.1
-                },
-                "text": item
+                "categories": {k: (v > 0.5) for k, v in scores.items()},
+                "category_scores": scores,
+                "text": text,
             }
-
         elif text_model == "detoxify":
-            detoxify_result = classify_text_with_detoxify(item)
-            flagged = detoxify_result["flagged"]
-
+            d = classify_text_with_detoxify(text)
+            scores = {
+                "hate": d["category_scores"].get("toxicity", 0.1), "hate/threatening": d["category_scores"].get("threat", 0.1),
+                "harassment": d["category_scores"].get("insult", 0.1), "harassment/threatening": d["category_scores"].get("threat", 0.1),
+                "self-harm": 0.1, "self-harm/intent": 0.1, "self-harm/instructions": 0.1,
+                "sexual": d["category_scores"].get("sexual_explicit", 0.1), "sexual/minors": d["category_scores"].get("sexual_explicit", 0.1),
+                "violence": d["category_scores"].get("threat", 0.1), "violence/graphic": d["category_scores"].get("threat", 0.1),
+                "nsfw": d["category_scores"].get("sexual_explicit", 0.1),
+            }
             return {
-                "flagged": flagged,
-                "categories": {
-                    "hate": detoxify_result["categories"].get("toxicity", False),
-                    "hate/threatening": detoxify_result["categories"].get("threat", False),
-                    "harassment": detoxify_result["categories"].get("insult", False),
-                    "harassment/threatening": detoxify_result["categories"].get("threat", False),
-                    "self-harm": False,
-                    "self-harm/intent": False,
-                    "self-harm/instructions": False,
-                    "sexual": detoxify_result["categories"].get("sexual_explicit", False),
-                    "sexual/minors": detoxify_result["categories"].get("sexual_explicit", False),
-                    "violence": detoxify_result["categories"].get("threat", False),
-                    "violence/graphic": detoxify_result["categories"].get("threat", False),
-                    "nsfw": detoxify_result["categories"].get("sexual_explicit", False)
-                },
-                "category_scores": {
-                    "hate": detoxify_result["category_scores"].get("toxicity", 0.1),
-                    "hate/threatening": detoxify_result["category_scores"].get("threat", 0.1),
-                    "harassment": detoxify_result["category_scores"].get("insult", 0.1),
-                    "harassment/threatening": detoxify_result["category_scores"].get("threat", 0.1),
-                    "self-harm": 0.1,
-                    "self-harm/intent": 0.1,
-                    "self-harm/instructions": 0.1,
-                    "sexual": detoxify_result["category_scores"].get("sexual_explicit", 0.1),
-                    "sexual/minors": detoxify_result["category_scores"].get("sexual_explicit", 0.1) * 0.9,
-                    "violence": detoxify_result["category_scores"].get("threat", 0.1),
-                    "violence/graphic": detoxify_result["category_scores"].get("threat", 0.1),
-                    "nsfw": detoxify_result["category_scores"].get("sexual_explicit", 0.1)
-                },
-                "text": item
+                "flagged": d["flagged"],
+                "categories": {k: (v > 0.5) for k, v in scores.items()},
+                "category_scores": scores,
+                "text": text,
             }
-
         elif text_model == "both":
-            gemma_result = classify_text_with_gemma(item)
-            detoxify_result = classify_text_with_detoxify(item)
-
+            gemma_result = classify_text_with_gemma(text)
+            detoxify_result = classify_text_with_detoxify(text)
             flagged = gemma_result["classification"] == "u" or detoxify_result["flagged"]
-
+            scores = {
+                "hate": max(0.9 if gemma_result["classification"] == "u" else 0.1, detoxify_result["category_scores"].get("toxicity", 0.1)),
+                "hate/threatening": max(0.9 if gemma_result["classification"] == "u" else 0.1, detoxify_result["category_scores"].get("threat", 0.1)),
+                "harassment": max(0.9 if gemma_result["classification"] == "u" else 0.1, detoxify_result["category_scores"].get("insult", 0.1)),
+                "harassment/threatening": max(0.9 if gemma_result["classification"] == "u" else 0.1, detoxify_result["category_scores"].get("threat", 0.1)),
+                "self-harm": 0.9 if gemma_result["classification"] == "u" else 0.1,
+                "self-harm/intent": 0.9 if gemma_result["classification"] == "u" else 0.1,
+                "self-harm/instructions": 0.9 if gemma_result["classification"] == "u" else 0.1,
+                "sexual": max(0.9 if gemma_result["classification"] == "u" else 0.1, detoxify_result["category_scores"].get("sexual_explicit", 0.1)),
+                "sexual/minors": max(0.9 if gemma_result["classification"] == "u" else 0.1, detoxify_result["category_scores"].get("sexual_explicit", 0.1)),
+                "violence": max(0.9 if gemma_result["classification"] == "u" else 0.1, detoxify_result["category_scores"].get("threat", 0.1)),
+                "violence/graphic": max(0.9 if gemma_result["classification"] == "u" else 0.1, detoxify_result["category_scores"].get("threat", 0.1)),
+                "nsfw": detoxify_result["category_scores"].get("sexual_explicit", 0.1),
+            }
             return {
                 "flagged": flagged,
-                "categories": {
-                    "hate": flagged,
-                    "hate/threatening": flagged,
-                    "harassment": flagged,
-                    "harassment/threatening": flagged,
-                    "self-harm": flagged,
-                    "self-harm/intent": flagged,
-                    "self-harm/instructions": flagged,
-                    "sexual": flagged,
-                    "sexual/minors": flagged,
-                    "violence": flagged,
-                    "violence/graphic": flagged,
-                    "nsfw": detoxify_result["categories"].get("sexual_explicit", False)
-                },
-                "category_scores": {
-                    "hate": 0.9 if flagged else 0.1,
-                    "hate/threatening": 0.9 if flagged else 0.1,
-                    "harassment": 0.9 if flagged else 0.1,
-                    "harassment/threatening": 0.9 if flagged else 0.1,
-                    "self-harm": 0.9 if flagged else 0.1,
-                    "self-harm/intent": 0.9 if flagged else 0.1,
-                    "self-harm/instructions": 0.9 if flagged else 0.1,
-                    "sexual": detoxify_result["category_scores"].get("sexual_explicit", 0.1),
-                    "sexual/minors": detoxify_result["category_scores"].get("sexual_explicit", 0.1) * 0.9,
-                    "violence": 0.9 if flagged else 0.1,
-                    "violence/graphic": 0.9 if flagged else 0.1,
-                    "nsfw": detoxify_result["category_scores"].get("sexual_explicit", 0.1)
-                },
-                "text": item
+                "categories": {k: (v > 0.5) for k, v in scores.items()},
+                "category_scores": scores,
+                "text": text,
             }
-
-    # Handle dictionary input (structured content)
-    elif isinstance(item, dict):
-        content_type = item.get("type")
+
+    elif content_type == "image":
+        image_data = None
+        image_url = item.get("url")
+        image_base64 = item.get("base64")
 
-        if content_type == "text":
-            text = item.get("text", "")
-
-            if text_model == "gemma":
-                gemma_result = classify_text_with_gemma(text)
-                flagged = gemma_result["classification"] == "u"
-
-                return {
-                    "flagged": flagged,
-                    "categories": {
-                        "hate": flagged,
-                        "hate/threatening": flagged,
-                        "harassment": flagged,
-                        "harassment/threatening": flagged,
-                        "self-harm": flagged,
-                        "self-harm/intent": flagged,
-                        "self-harm/instructions": flagged,
-                        "sexual": flagged,
-                        "sexual/minors": flagged,
-                        "violence": flagged,
-                        "violence/graphic": flagged,
-                        "nsfw": False
-                    },
-                    "category_scores": {
-                        "hate": 0.9 if flagged else 0.1,
-                        "hate/threatening": 0.9 if flagged else 0.1,
-                        "harassment": 0.9 if flagged else 0.1,
-                        "harassment/threatening": 0.9 if flagged else 0.1,
-                        "self-harm": 0.9 if flagged else 0.1,
-                        "self-harm/intent": 0.9 if flagged else 0.1,
-                        "self-harm/instructions": 0.9 if flagged else 0.1,
-                        "sexual": 0.9 if flagged else 0.1,
-                        "sexual/minors": 0.9 if flagged else 0.1,
-                        "violence": 0.9 if flagged else 0.1,
-                        "violence/graphic": 0.9 if flagged else 0.1,
-                        "nsfw": 0.1
-                    },
-                    "text": text
-                }
-
-            elif text_model == "detoxify":
-                detoxify_result = classify_text_with_detoxify(text)
-                flagged = detoxify_result["flagged"]
-
-                return {
-                    "flagged": flagged,
-                    "categories": {
-                        "hate": detoxify_result["categories"].get("toxicity", False),
-                        "hate/threatening": detoxify_result["categories"].get("threat", False),
-                        "harassment": detoxify_result["categories"].get("insult", False),
-                        "harassment/threatening": detoxify_result["categories"].get("threat", False),
-                        "self-harm": False,
-                        "self-harm/intent": False,
-                        "self-harm/instructions": False,
-                        "sexual": detoxify_result["categories"].get("sexual_explicit", False),
-                        "sexual/minors": detoxify_result["categories"].get("sexual_explicit", False),
-                        "violence": detoxify_result["categories"].get("threat", False),
-                        "violence/graphic": detoxify_result["categories"].get("threat", False),
-                        "nsfw": detoxify_result["categories"].get("sexual_explicit", False)
-                    },
-                    "category_scores": {
-                        "hate": detoxify_result["category_scores"].get("toxicity", 0.1),
-                        "hate/threatening": detoxify_result["category_scores"].get("threat", 0.1),
-                        "harassment": detoxify_result["category_scores"].get("insult", 0.1),
-                        "harassment/threatening": detoxify_result["category_scores"].get("threat", 0.1),
-                        "self-harm": 0.1,
-                        "self-harm/intent": 0.1,
-                        "self-harm/instructions": 0.1,
-                        "sexual": detoxify_result["category_scores"].get("sexual_explicit", 0.1),
-                        "sexual/minors": detoxify_result["category_scores"].get("sexual_explicit", 0.1) * 0.9,
-                        "violence": detoxify_result["category_scores"].get("threat", 0.1),
-                        "violence/graphic": detoxify_result["category_scores"].get("threat", 0.1),
-                        "nsfw": detoxify_result["category_scores"].get("sexual_explicit", 0.1)
-                    },
-                    "text": text
-                }
-
-            elif text_model == "both":
-                gemma_result = classify_text_with_gemma(text)
-                detoxify_result = classify_text_with_detoxify(text)
-
-                flagged = gemma_result["classification"] == "u" or detoxify_result["flagged"]
-
-                return {
-                    "flagged": flagged,
-                    "categories": {
-                        "hate": flagged,
-                        "hate/threatening": flagged,
-                        "harassment": flagged,
-                        "harassment/threatening": flagged,
-                        "self-harm": flagged,
-                        "self-harm/intent": flagged,
-                        "self-harm/instructions": flagged,
-                        "sexual": flagged,
-                        "sexual/minors": flagged,
-                        "violence": flagged,
-                        "violence/graphic": flagged,
-                        "nsfw": detoxify_result["categories"].get("sexual_explicit", False)
-                    },
-                    "category_scores": {
-                        "hate": 0.9 if flagged else 0.1,
-                        "hate/threatening": 0.9 if flagged else 0.1,
-                        "harassment": 0.9 if flagged else 0.1,
-                        "harassment/threatening": 0.9 if flagged else 0.1,
-                        "self-harm": 0.9 if flagged else 0.1,
-                        "self-harm/intent": 0.9 if flagged else 0.1,
-                        "self-harm/instructions": 0.9 if flagged else 0.1,
-                        "sexual": detoxify_result["category_scores"].get("sexual_explicit", 0.1),
-                        "sexual/minors": detoxify_result["category_scores"].get("sexual_explicit", 0.1) * 0.9,
-                        "violence": 0.9 if flagged else 0.1,
-                        "violence/graphic": 0.9 if flagged else 0.1,
-                        "nsfw": detoxify_result["category_scores"].get("sexual_explicit", 0.1)
-                    },
-                    "text": text
-                }
+        if image_url:
+            try:
+                response = requests.get(image_url, timeout=10)
+                response.raise_for_status()
+                image_data = response.content
+            except requests.RequestException as e:
+                print(f"Error fetching image from URL {image_url}: {e}")
+        elif image_base64:
+            try:
+                if image_base64.startswith("data:image"):
+                    image_base64 = image_base64.split(",")[1]
+                image_data = base64.b64decode(image_base64)
+            except Exception as e:
+                print(f"Error decoding base64 image: {e}")
 
-        elif content_type == "image":
-            image_data = None
-            image_url = item.get("url")
-            image_base64 = item.get("base64")
-
-            # Get image data from URL
-            if image_url:
-                try:
-                    response = requests.get(image_url)
-                    if response.status_code == 200:
-                        image_data = response.content
-                    else:
-                        print(f"Failed to fetch image from URL: {image_url}, status code: {response.status_code}")
-                except Exception as e:
-                    print(f"Error fetching image from URL: {str(e)}")
-
-            # Get image data from base64
-            elif image_base64:
-                try:
-                    if image_base64.startswith("data:image"):
-                        base64_data = image_base64.split(",")[1]
-                    else:
-                        base64_data = image_base64
-
-                    image_data = base64.b64decode(base64_data)
-                except Exception as e:
-                    print(f"Error decoding base64 image: {str(e)}")
-
-            # Process the image if we have data
-            if image_data:
-                image_result = classify_image(image_data)
-                flagged = image_result["classification"] == "u"
-
-                return {
-                    "flagged": flagged,
-                    "categories": {
-                        "hate": False,
-                        "hate/threatening": False,
-                        "harassment": False,
-                        "harassment/threatening": False,
-                        "self-harm": False,
-                        "self-harm/intent": False,
-                        "self-harm/instructions": False,
-                        "sexual": flagged,
-                        "sexual/minors": flagged,
-                        "violence": False,
-                        "violence/graphic": False,
-                        "nsfw": flagged
-                    },
-                    "category_scores": {
-                        "hate": 0.1,
-                        "hate/threatening": 0.1,
-                        "harassment": 0.1,
-                        "harassment/threatening": 0.1,
-                        "self-harm": 0.1,
-                        "self-harm/intent": 0.1,
-                        "self-harm/instructions": 0.1,
-                        "sexual": image_result["nsfw_score"],
-                        "sexual/minors": image_result["nsfw_score"] * 0.9,
-                        "violence": 0.1,
-                        "violence/graphic": 0.1,
-                        "nsfw": image_result["nsfw_score"]
-                    },
-                    "image_url": image_url,
-                    "image_base64": image_base64[:50] + "..." if image_base64 and len(image_base64) > 50 else image_base64
-                }
-            else:
-                # Return error if no image data
-                return {
-                    "flagged": False,
-                    "categories": {
-                        "hate": False,
-                        "hate/threatening": False,
-                        "harassment": False,
-                        "harassment/threatening": False,
-                        "self-harm": False,
-                        "self-harm/intent": False,
-                        "self-harm/instructions": False,
-                        "sexual": False,
-                        "sexual/minors": False,
-                        "violence": False,
-                        "violence/graphic": False,
-                        "nsfw": False
-                    },
-                    "category_scores": {
-                        "hate": 0.1,
-                        "hate/threatening": 0.1,
-                        "harassment": 0.1,
-                        "harassment/threatening": 0.1,
-                        "self-harm": 0.1,
-                        "self-harm/intent": 0.1,
-                        "self-harm/instructions": 0.1,
-                        "sexual": 0.1,
-                        "sexual/minors": 0.1,
-                        "violence": 0.1,
-                        "violence/graphic": 0.1,
-                        "nsfw": 0.1
-                    },
-                    "image_url": image_url,
-                    "image_base64": image_base64[:50] + "..." if image_base64 and len(image_base64) > 50 else image_base64
-                }
+        if image_data:
+            image_result = classify_image(image_data)
+            flagged = image_result["classification"] == "u"
+            nsfw_score = image_result.get("nsfw_score", 0.1)
+            scores = {
+                "hate": 0.1, "hate/threatening": 0.1,
+                "harassment": 0.1, "harassment/threatening": 0.1,
+                "self-harm": 0.1, "self-harm/intent": 0.1, "self-harm/instructions": 0.1,
+                "sexual": nsfw_score, "sexual/minors": nsfw_score,
+                "violence": 0.1, "violence/graphic": 0.1,
+                "nsfw": nsfw_score,
+            }
+            return {
+                "flagged": flagged,
+                "categories": {k: (v > 0.5) for k, v in scores.items()},
+                "category_scores": scores,
+                "image_url": image_url,
+                "image_base64": image_base64[:50] + "..." if isinstance(image_base64, str) and len(image_base64) > 50 else None,
+            }
 
-    # Default return for invalid items
+    default_scores = {
+        "hate": 0.1, "hate/threatening": 0.1, "harassment": 0.1, "harassment/threatening": 0.1,
+        "self-harm": 0.1, "self-harm/intent": 0.1, "self-harm/instructions": 0.1,
+        "sexual": 0.1, "sexual/minors": 0.1, "violence": 0.1, "violence/graphic": 0.1,
+        "nsfw": 0.1
+    }
     return {
         "flagged": False,
-        "categories": {
-            "hate": False,
-            "hate/threatening": False,
-            "harassment": False,
-            "harassment/threatening": False,
-            "self-harm": False,
-            "self-harm/intent": False,
-            "self-harm/instructions": False,
-            "sexual": False,
-            "sexual/minors": False,
-            "violence": False,
-            "violence/graphic": False,
-            "nsfw": False
-        },
-        "category_scores": {
-            "hate": 0.1,
-            "hate/threatening": 0.1,
-            "harassment": 0.1,
-            "harassment/threatening": 0.1,
-            "self-harm": 0.1,
-            "self-harm/intent": 0.1,
-            "self-harm/instructions": 0.1,
-            "sexual": 0.1,
-            "sexual/minors": 0.1,
-            "violence": 0.1,
-            "violence/graphic": 0.1,
-            "nsfw": 0.1
-        }
+        "categories": {k: False for k in default_scores},
+        "category_scores": default_scores,
+        "error": f"Invalid or unprocessable item: {item}"
     }
 
 def get_api_key(request: Request):
@@ -755,16 +499,12 @@ async def moderate_content(
     input_data = request.input
     text_model = request.model or "gemma"
 
-    # Normalize input to a list of items
     items = []
-
     if isinstance(input_data, str):
-        # Single string input
-        items = [input_data]
+        items.append(input_data)
         total_tokens += count_tokens(input_data)
     elif isinstance(input_data, list):
-        # List of items
-        items = input_data
+        items.extend(input_data)
         for item in items:
             if isinstance(item, str):
                 total_tokens += count_tokens(item)
@@ -775,19 +515,16 @@ async def moderate_content(
 
     if len(items) > 10:
         raise HTTPException(status_code=400, detail="Too many input items. Maximum 10 allowed.")
-
-    # Process each item individually
-    results = []
-    for item in items:
-        result = process_content_item(item, text_model)
-        results.append(result)
+
+    tasks = [run_in_threadpool(process_content_item, item, text_model) for item in items]
+    results = await asyncio.gather(*tasks)
 
     response_data = {
         "id": f"modr_{uuid.uuid4().hex[:24]}",
         "object": "moderation",
         "created": int(time.time()),
         "model": text_model,
-        "results": results
+        "results": list(results)
     }
 
     track_request_metrics(start_time, total_tokens)
@@ -801,7 +538,7 @@ async def moderate_content(
 async def get_metrics(api_key: str = Depends(get_api_key)):
     return get_performance_metrics()
 
-with open("templates/index.html", "w") as f:
+with open("templates/index.html", "w", encoding='utf-8') as f:
     f.write("""<!DOCTYPE html>
 <html lang="en">
 <head>
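Note on the pattern this commit introduces: per-item moderation work is now handed to run_in_threadpool and awaited with asyncio.gather instead of running in a sequential loop. A minimal, self-contained sketch of that pattern follows; the slow_classify helper and the /demo-moderate route are illustrative stand-ins, not code from app.py.

import asyncio
import time

from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool

app = FastAPI()

def slow_classify(text: str) -> dict:
    # Stand-in for a blocking, CPU-bound classifier (process_content_item in app.py).
    time.sleep(0.5)
    return {"flagged": False, "text": text}

@app.post("/demo-moderate")
async def demo_moderate(texts: list[str]):
    # Each blocking call runs in a worker thread, keeping the event loop responsive;
    # asyncio.gather waits for all items concurrently and preserves their order.
    tasks = [run_in_threadpool(slow_classify, t) for t in texts]
    results = await asyncio.gather(*tasks)
    return {"results": list(results)}

In app.py the same idea applies with process_content_item as the blocking function, which is why the diff above adds import asyncio and from fastapi.concurrency import run_in_threadpool.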