ALSv commited on
Commit
4fc8fe7
·
verified ·
1 Parent(s): a7b98c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -9
app.py CHANGED
@@ -40,12 +40,11 @@ model = Classifier()
40
  model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
41
  model.eval()
42
 
43
- # ---------------- PREDICT ----------------
44
  def predict(base64_input: str):
45
  if not base64_input:
46
  return "Nessun input fornito", {}
47
 
48
- # se arriva "data:image/jpeg;base64,...."
49
  if base64_input.startswith("data:image"):
50
  base64_input = base64_input.split(",", 1)[1]
51
 
@@ -62,7 +61,7 @@ def predict(base64_input: str):
62
 
63
  return max_label, probs_dict
64
 
65
- # ---------------- HELPERS ----------------
66
  def image_to_base64(img: Image.Image):
67
  if img is None:
68
  return ""
@@ -70,9 +69,9 @@ def image_to_base64(img: Image.Image):
70
  img.save(buf, format="JPEG", quality=90)
71
  return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode("utf-8")
72
 
73
- # ---------------- UI ----------------
74
  with gr.Blocks(title="NSFW Classifier") as demo:
75
- gr.Markdown("## 🎨 NSFW Image Classifier\nCarica un'immagine o incolla una stringa base64.\n\nAPI disponibile su **/api/predict**")
76
 
77
  with gr.Row():
78
  with gr.Column(scale=2):
@@ -87,10 +86,10 @@ with gr.Blocks(title="NSFW Classifier") as demo:
87
  label_output = gr.Textbox(label="Classe predetta")
88
  result_display = gr.Label(label="Distribuzione probabilità", num_top_classes=len(labels))
89
 
90
- # immagine -> converte in base64
91
  img_input.change(fn=image_to_base64, inputs=img_input, outputs=base64_input)
92
 
93
- # unico endpoint API
94
  analyze_btn.click(fn=predict,
95
  inputs=base64_input,
96
  outputs=[label_output, result_display],
@@ -98,7 +97,6 @@ with gr.Blocks(title="NSFW Classifier") as demo:
98
 
99
  clear_btn.click(fn=lambda: "", inputs=None, outputs=base64_input)
100
 
101
- # ---------------- LAUNCH ----------------
102
  if __name__ == "__main__":
103
  demo.launch()
104
-
 
40
  model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
41
  model.eval()
42
 
43
+ # ---------------- FUNZIONE ----------------
44
  def predict(base64_input: str):
45
  if not base64_input:
46
  return "Nessun input fornito", {}
47
 
 
48
  if base64_input.startswith("data:image"):
49
  base64_input = base64_input.split(",", 1)[1]
50
 
 
61
 
62
  return max_label, probs_dict
63
 
64
+ # ---------------- HELPER ----------------
65
  def image_to_base64(img: Image.Image):
66
  if img is None:
67
  return ""
 
69
  img.save(buf, format="JPEG", quality=90)
70
  return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode("utf-8")
71
 
72
+ # ---------------- INTERFACCIA ----------------
73
  with gr.Blocks(title="NSFW Classifier") as demo:
74
+ gr.Markdown("## 🎨 NSFW Image Classifier\nCarica un'immagine o incolla la stringa base64.\n\nAPI standard: **/api/predict**")
75
 
76
  with gr.Row():
77
  with gr.Column(scale=2):
 
86
  label_output = gr.Textbox(label="Classe predetta")
87
  result_display = gr.Label(label="Distribuzione probabilità", num_top_classes=len(labels))
88
 
89
+ # immagine converte in base64 → textbox
90
  img_input.change(fn=image_to_base64, inputs=img_input, outputs=base64_input)
91
 
92
+ # unico endpoint API standard
93
  analyze_btn.click(fn=predict,
94
  inputs=base64_input,
95
  outputs=[label_output, result_display],
 
97
 
98
  clear_btn.click(fn=lambda: "", inputs=None, outputs=base64_input)
99
 
100
+ # ---------------- AVVIO ----------------
101
  if __name__ == "__main__":
102
  demo.launch()