Update app.py
Browse files
app.py
CHANGED
|
@@ -40,12 +40,11 @@ model = Classifier()
|
|
| 40 |
model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
|
| 41 |
model.eval()
|
| 42 |
|
| 43 |
-
# ----------------
|
| 44 |
def predict(base64_input: str):
|
| 45 |
if not base64_input:
|
| 46 |
return "Nessun input fornito", {}
|
| 47 |
|
| 48 |
-
# se arriva "data:image/jpeg;base64,...."
|
| 49 |
if base64_input.startswith("data:image"):
|
| 50 |
base64_input = base64_input.split(",", 1)[1]
|
| 51 |
|
|
@@ -62,7 +61,7 @@ def predict(base64_input: str):
|
|
| 62 |
|
| 63 |
return max_label, probs_dict
|
| 64 |
|
| 65 |
-
# ----------------
|
| 66 |
def image_to_base64(img: Image.Image):
|
| 67 |
if img is None:
|
| 68 |
return ""
|
|
@@ -70,9 +69,9 @@ def image_to_base64(img: Image.Image):
|
|
| 70 |
img.save(buf, format="JPEG", quality=90)
|
| 71 |
return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode("utf-8")
|
| 72 |
|
| 73 |
-
# ----------------
|
| 74 |
with gr.Blocks(title="NSFW Classifier") as demo:
|
| 75 |
-
gr.Markdown("## 🎨 NSFW Image Classifier\nCarica un'immagine o incolla
|
| 76 |
|
| 77 |
with gr.Row():
|
| 78 |
with gr.Column(scale=2):
|
|
@@ -87,10 +86,10 @@ with gr.Blocks(title="NSFW Classifier") as demo:
|
|
| 87 |
label_output = gr.Textbox(label="Classe predetta")
|
| 88 |
result_display = gr.Label(label="Distribuzione probabilità", num_top_classes=len(labels))
|
| 89 |
|
| 90 |
-
# immagine
|
| 91 |
img_input.change(fn=image_to_base64, inputs=img_input, outputs=base64_input)
|
| 92 |
|
| 93 |
-
# unico endpoint API
|
| 94 |
analyze_btn.click(fn=predict,
|
| 95 |
inputs=base64_input,
|
| 96 |
outputs=[label_output, result_display],
|
|
@@ -98,7 +97,6 @@ with gr.Blocks(title="NSFW Classifier") as demo:
|
|
| 98 |
|
| 99 |
clear_btn.click(fn=lambda: "", inputs=None, outputs=base64_input)
|
| 100 |
|
| 101 |
-
# ----------------
|
| 102 |
if __name__ == "__main__":
|
| 103 |
demo.launch()
|
| 104 |
-
|
|
|
|
| 40 |
model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
|
| 41 |
model.eval()
|
| 42 |
|
| 43 |
+
# ---------------- FUNZIONE ----------------
|
| 44 |
def predict(base64_input: str):
|
| 45 |
if not base64_input:
|
| 46 |
return "Nessun input fornito", {}
|
| 47 |
|
|
|
|
| 48 |
if base64_input.startswith("data:image"):
|
| 49 |
base64_input = base64_input.split(",", 1)[1]
|
| 50 |
|
|
|
|
| 61 |
|
| 62 |
return max_label, probs_dict
|
| 63 |
|
| 64 |
+
# ---------------- HELPER ----------------
|
| 65 |
def image_to_base64(img: Image.Image):
|
| 66 |
if img is None:
|
| 67 |
return ""
|
|
|
|
| 69 |
img.save(buf, format="JPEG", quality=90)
|
| 70 |
return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode("utf-8")
|
| 71 |
|
| 72 |
+
# ---------------- INTERFACCIA ----------------
|
| 73 |
with gr.Blocks(title="NSFW Classifier") as demo:
|
| 74 |
+
gr.Markdown("## 🎨 NSFW Image Classifier\nCarica un'immagine o incolla la stringa base64.\n\nAPI standard: **/api/predict**")
|
| 75 |
|
| 76 |
with gr.Row():
|
| 77 |
with gr.Column(scale=2):
|
|
|
|
| 86 |
label_output = gr.Textbox(label="Classe predetta")
|
| 87 |
result_display = gr.Label(label="Distribuzione probabilità", num_top_classes=len(labels))
|
| 88 |
|
| 89 |
+
# immagine → converte in base64 → textbox
|
| 90 |
img_input.change(fn=image_to_base64, inputs=img_input, outputs=base64_input)
|
| 91 |
|
| 92 |
+
# unico endpoint API standard
|
| 93 |
analyze_btn.click(fn=predict,
|
| 94 |
inputs=base64_input,
|
| 95 |
outputs=[label_output, result_display],
|
|
|
|
| 97 |
|
| 98 |
clear_btn.click(fn=lambda: "", inputs=None, outputs=base64_input)
|
| 99 |
|
| 100 |
+
# ---------------- AVVIO ----------------
|
| 101 |
if __name__ == "__main__":
|
| 102 |
demo.launch()
|
|
|