Update app.py
app.py
CHANGED
@@ -94,6 +94,7 @@ detoxify_model = Detoxify('multilingual')
 print("Loading NSFW image classification model...")
 nsfw_model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")
 nsfw_processor = ViTImageProcessor.from_pretrained('Falconsai/nsfw_image_detection')
+nsfw_model.eval()  # Set to evaluation mode
 print("NSFW image classification model loaded.")
 
 MODERATION_SYSTEM_PROMPT = (
@@ -305,11 +306,15 @@ def classify_text_with_detoxify(text):
 
 def classify_image(image_data):
     try:
+        # Open and convert the image
         img = Image.open(io.BytesIO(image_data)).convert("RGB")
 
-        #
+        # Process the image with the NSFW model
         with torch.no_grad():
             inputs = nsfw_processor(images=img, return_tensors="pt")
+            # Move to the same device as the model
+            inputs = {k: v.to(nsfw_model.device) for k, v in inputs.items()}
+
             outputs = nsfw_model(**inputs)
             logits = outputs.logits
 
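
Note: the classify_image changes above follow the standard transformers image-classification loop. A standalone sketch, assuming the same Falconsai/nsfw_image_detection checkpoint; the softmax/id2label post-processing shown here is the usual pattern for this model class, not lines taken from this diff:

import io
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

processor = ViTImageProcessor.from_pretrained("Falconsai/nsfw_image_detection")
model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")
model.eval()  # inference behaviour for dropout-style layers, matching the added eval() call

def classify_bytes(image_bytes: bytes) -> dict:
    # Decode raw bytes and normalise to RGB, as app.py does
    img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    with torch.no_grad():
        inputs = processor(images=img, return_tensors="pt")
        # Keep tensors on the same device as the model weights
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        logits = model(**inputs).logits
    probs = logits.softmax(dim=-1)[0]
    idx = int(probs.argmax())
    return {"label": model.config.id2label[idx], "score": float(probs[idx])}
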
@@ -336,6 +341,7 @@ def classify_image(image_data):
             "nsfw_score": nsfw_score
         }
     except Exception as e:
+        print(f"Error in classify_image: {str(e)}")
         return {
             "classification": "s",
             "label": "ERROR",
@@ -345,6 +351,7 @@ def classify_image(image_data):
         }
 
 def process_content_item(item, text_model="gemma"):
+    # Handle string input (simple text)
     if isinstance(item, str):
         if text_model == "gemma":
             gemma_result = classify_text_with_gemma(item)
@@ -459,8 +466,11 @@ def process_content_item(item, text_model="gemma"):
                 "text": item
             }
 
+    # Handle dictionary input (structured content)
     elif isinstance(item, dict):
-
+        content_type = item.get("type")
+
+        if content_type == "text":
             text = item.get("text", "")
 
             if text_model == "gemma":
@@ -576,89 +586,35 @@ def process_content_item(item, text_model="gemma"):
                 "text": text
             }
 
-        elif
+        elif content_type == "image":
             image_data = None
+            image_url = item.get("url")
+            image_base64 = item.get("base64")
 
-
+            # Get image data from URL
+            if image_url:
                 try:
-                    response = requests.get(
-
-
-
-                    "
-
-
-                        "hate/threatening": False,
-                        "harassment": False,
-                        "harassment/threatening": False,
-                        "self-harm": False,
-                        "self-harm/intent": False,
-                        "self-harm/instructions": False,
-                        "sexual": False,
-                        "sexual/minors": False,
-                        "violence": False,
-                        "violence/graphic": False,
-                        "nsfw": False
-                    },
-                    "category_scores": {
-                        "hate": 0.1,
-                        "hate/threatening": 0.1,
-                        "harassment": 0.1,
-                        "harassment/threatening": 0.1,
-                        "self-harm": 0.1,
-                        "self-harm/intent": 0.1,
-                        "self-harm/instructions": 0.1,
-                        "sexual": 0.1,
-                        "sexual/minors": 0.1,
-                        "violence": 0.1,
-                        "violence/graphic": 0.1,
-                        "nsfw": 0.1
-                    },
-                    "image_url": item.get("url")
-                }
+                    response = requests.get(image_url)
+                    if response.status_code == 200:
+                        image_data = response.content
+                    else:
+                        print(f"Failed to fetch image from URL: {image_url}, status code: {response.status_code}")
+                except Exception as e:
+                    print(f"Error fetching image from URL: {str(e)}")
 
-
+            # Get image data from base64
+            elif image_base64:
                 try:
-                    if
-                        base64_data =
+                    if image_base64.startswith("data:image"):
+                        base64_data = image_base64.split(",")[1]
                     else:
-                        base64_data =
+                        base64_data = image_base64
 
                     image_data = base64.b64decode(base64_data)
-                except Exception:
-
-                    "flagged": False,
-                    "categories": {
-                        "hate": False,
-                        "hate/threatening": False,
-                        "harassment": False,
-                        "harassment/threatening": False,
-                        "self-harm": False,
-                        "self-harm/intent": False,
-                        "self-harm/instructions": False,
-                        "sexual": False,
-                        "sexual/minors": False,
-                        "violence": False,
-                        "violence/graphic": False,
-                        "nsfw": False
-                    },
-                    "category_scores": {
-                        "hate": 0.1,
-                        "hate/threatening": 0.1,
-                        "harassment": 0.1,
-                        "harassment/threatening": 0.1,
-                        "self-harm": 0.1,
-                        "self-harm/intent": 0.1,
-                        "self-harm/instructions": 0.1,
-                        "sexual": 0.1,
-                        "sexual/minors": 0.1,
-                        "violence": 0.1,
-                        "violence/graphic": 0.1,
-                        "nsfw": 0.1
-                    },
-                    "image_base64": item.get("base64")[:50] + "..." if len(item.get("base64", "")) > 50 else item.get("base64", "")
-                }
+                except Exception as e:
+                    print(f"Error decoding base64 image: {str(e)}")
 
+            # Process the image if we have data
             if image_data:
                 image_result = classify_image(image_data)
                 flagged = image_result["classification"] == "u"
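
Note: the new image branch accepts either a URL or a base64 payload (optionally carrying a data:image...;base64 prefix). A compact, hypothetical helper capturing the same decision tree; the name load_image_bytes and the timeout argument are additions for illustration, not code from this commit:

import base64
from typing import Optional

import requests

def load_image_bytes(url: Optional[str] = None, b64: Optional[str] = None) -> Optional[bytes]:
    # Mirror of the diff's two branches: fetch over HTTP, or decode inline base64.
    if url:
        try:
            resp = requests.get(url, timeout=10)  # timeout added here; the diff calls requests.get(image_url) without one
            if resp.status_code == 200:
                return resp.content
            print(f"Failed to fetch image from URL: {url}, status code: {resp.status_code}")
        except Exception as exc:
            print(f"Error fetching image from URL: {exc}")
    elif b64:
        try:
            # Strip a "data:image/...;base64," prefix if present
            payload = b64.split(",")[1] if b64.startswith("data:image") else b64
            return base64.b64decode(payload)
        except Exception as exc:
            print(f"Error decoding base64 image: {exc}")
    return None
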
@@ -693,10 +649,46 @@ def process_content_item(item, text_model="gemma"):
                         "violence/graphic": 0.1,
                         "nsfw": image_result["nsfw_score"]
                     },
-                    "image_url":
-                    "image_base64":
+                    "image_url": image_url,
+                    "image_base64": image_base64[:50] + "..." if image_base64 and len(image_base64) > 50 else image_base64
+                }
+            else:
+                # Return error if no image data
+                return {
+                    "flagged": False,
+                    "categories": {
+                        "hate": False,
+                        "hate/threatening": False,
+                        "harassment": False,
+                        "harassment/threatening": False,
+                        "self-harm": False,
+                        "self-harm/intent": False,
+                        "self-harm/instructions": False,
+                        "sexual": False,
+                        "sexual/minors": False,
+                        "violence": False,
+                        "violence/graphic": False,
+                        "nsfw": False
+                    },
+                    "category_scores": {
+                        "hate": 0.1,
+                        "hate/threatening": 0.1,
+                        "harassment": 0.1,
+                        "harassment/threatening": 0.1,
+                        "self-harm": 0.1,
+                        "self-harm/intent": 0.1,
+                        "self-harm/instructions": 0.1,
+                        "sexual": 0.1,
+                        "sexual/minors": 0.1,
+                        "violence": 0.1,
+                        "violence/graphic": 0.1,
+                        "nsfw": 0.1
+                    },
+                    "image_url": image_url,
+                    "image_base64": image_base64[:50] + "..." if image_base64 and len(image_base64) > 50 else image_base64
                 }
 
+    # Default return for invalid items
     return {
         "flagged": False,
         "categories": {
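
Note: every branch of process_content_item returns the same OpenAI-moderation-style shape; the image paths in this commit additionally echo image_url and image_base64 (truncated to 50 characters when long). A representative result, with placeholder values rather than real model output:

CATEGORY_KEYS = [
    "hate", "hate/threatening", "harassment", "harassment/threatening",
    "self-harm", "self-harm/intent", "self-harm/instructions",
    "sexual", "sexual/minors", "violence", "violence/graphic", "nsfw",
]

example_result = {
    "flagged": False,
    "categories": {key: False for key in CATEGORY_KEYS},
    "category_scores": {key: 0.1 for key in CATEGORY_KEYS},  # 0.1 is the fallback score used throughout the diff
    "image_url": None,      # echoed back for image items fetched by URL
    "image_base64": None,   # echoed back, truncated, for inline base64 items
}
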
@@ -763,10 +755,15 @@ async def moderate_content(
     input_data = request.input
     text_model = request.model or "gemma"
 
+    # Normalize input to a list of items
+    items = []
+
     if isinstance(input_data, str):
+        # Single string input
         items = [input_data]
         total_tokens += count_tokens(input_data)
     elif isinstance(input_data, list):
+        # List of items
         items = input_data
         for item in items:
             if isinstance(item, str):
@@ -779,6 +776,7 @@ async def moderate_content(
     if len(items) > 10:
         raise HTTPException(status_code=400, detail="Too many input items. Maximum 10 allowed.")
 
+    # Process each item individually
     results = []
    for item in items:
        result = process_content_item(item, text_model)
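
Note: taken together, moderate_content accepts either a single string or a list of up to ten mixed text and image items. A hedged usage sketch with the requests library; the route /v1/moderations and the port are assumptions for illustration, since the FastAPI route decorator is not part of this diff:

import requests

payload = {
    "model": "gemma",  # text classifier selected in process_content_item
    "input": [
        "some text to moderate",
        {"type": "image", "url": "https://example.com/picture.jpg"},
    ],
}

# Assumed route and host; adjust to wherever this Space exposes the endpoint.
resp = requests.post("http://localhost:7860/v1/moderations", json=payload, timeout=30)
print(resp.json())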