Update app.py
app.py
CHANGED
@@ -2,124 +2,90 @@ import gradio as gr
 from PIL import Image, ImageEnhance
 import numpy as np
 import torch
-from transformers import AutoProcessor, AutoModel, pipeline, ViTFeatureExtractor, ViTForImageClassification

-# ------------------ Device ------------------
 device = "cuda" if torch.cuda.is_available() else "cpu"

-# ------------------ Aesthetic Scorer ------------------
-
-
-
-aesthetic_model.eval()

-def aesthetic_score(img_pil):
-    inputs =
     with torch.no_grad():
-
-
-
-
-

-def zero_dce_enhance(img_pil):
     enhanced = zero_dce_pipe(img_pil)
     return enhanced[0]

-# ------------------ Classifier ------------------
-
-
-cls_model = ViTForImageClassification.from_pretrained(cls_model_name).to(device)
 cls_model.eval()

 def classify_image(img_pil):
-    inputs =
     with torch.no_grad():
         logits = cls_model(**inputs).logits
-
-    label
-    return label.lower()

-# ------------------ Category Vibes ------------------
 CATEGORY_VIBES = {
-    "person": [
-
-
-    ],
-    "food": [
-        dict(name="Food Vibrant", exposure_stops=0.1, contrast=0.15, saturation=0.25, warmth=0.05, clarity=0.1),
-        dict(name="Food Natural", exposure_stops=0.05, contrast=0.05, saturation=0.1, warmth=0.02, clarity=0.05),
-    ],
-    "landscape": [
-        dict(name="Landscape Punch", exposure_stops=0.1, contrast=0.2, saturation=0.2, warmth=0.05, clarity=0.15),
-        dict(name="Landscape Film", exposure_stops=0.0, contrast=0.1, saturation=0.05, warmth=0.02, clarity=0.1),
-    ],
-    "default": [
-        dict(name="Pop", exposure_stops=0.05, contrast=0.2, saturation=0.2, warmth=0.05, clarity=0.1),
-        dict(name="Moody", exposure_stops=-0.05, contrast=0.15, saturation=-0.05, warmth=-0.05, clarity=0.2),
-    ],
 }

-
-
-    img = img.convert("RGB")
-    if exposure_stops != 0:
-
-
-    if contrast != 0:
-        img = ImageEnhance.Contrast(img).enhance(1 + contrast)
-    if saturation != 0:
-        img = ImageEnhance.Color(img).enhance(1 + saturation)
-    if clarity != 0:
         arr = np.array(img).astype(np.float32)
         arr = np.clip(arr * (1 + clarity), 0, 255).astype(np.uint8)
         img = Image.fromarray(arr)
-    if warmth != 0:
         r, g, b = img.split()
-        r = r.point(lambda i: min(255, i * (1 + warmth)))
-        b = b.point(lambda i: min(255, i * (1 - warmth)))
-        img = Image.merge("RGB", (r, g, b))
     return img

-# ------------------ Main Process ------------------
 def process(image):
-
-    enhanced = zero_dce_enhance(image)
-
-    # Step 2: Classify image
     label = classify_image(enhanced)
-
-
-
-        vibes = CATEGORY_VIBES["person"]
-    elif "food" in label or "dish" in label:
-        vibes = CATEGORY_VIBES["food"]
-    elif "landscape" in label or "tree" in label or "mountain" in label:
-        vibes = CATEGORY_VIBES["landscape"]
-    else:
-        vibes = CATEGORY_VIBES["default"]
-
-    # Step 3: Apply vibes + score
-    candidates = []
     for vibe in vibes:
-        out = apply_adjustments(enhanced, **
-        score =
-
-
-
-    best
-    score, vibe_name, img_out = best

-    return img_out, f"Classified as: {label} → Chosen: {vibe_name} (score {score:.2f})"
-
-# ------------------ UI ------------------
 demo = gr.Interface(
     fn=process,
     inputs=gr.Image(type="pil"),
-    outputs=[gr.Image(type="pil"), gr.
-    title="
-    description="
 )
-
 if __name__ == "__main__":
     demo.launch()
 from PIL import Image, ImageEnhance
 import numpy as np
 import torch
+from transformers import AutoProcessor, AutoModel, pipeline, ViTFeatureExtractor, ViTForImageClassification, CLIPProcessor
+import cv2

 device = "cuda" if torch.cuda.is_available() else "cpu"

+# Aesthetic Scorer: rsinema/aesthetic-scorer (public)
+ae_processor = CLIPProcessor.from_pretrained("rsinema/aesthetic-scorer")
+ae_model = AutoModel.from_pretrained("rsinema/aesthetic-scorer").to(device)
+ae_model.eval()
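+# NOTE: aesthetic_score() below feeds raw pixel_values straight into ae_model();
+# this assumes the repo's custom model takes an image tensor and returns
+# per-dimension scores with the overall aesthetic score first.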

+def aesthetic_score(img_pil):
+    inputs = ae_processor(images=img_pil, return_tensors="pt")["pixel_values"].to(device)
     with torch.no_grad():
+        scores = ae_model(inputs)
+    # scores returns 7 dims; first is overall aesthetic
+    return float(scores[0][0].item())
+
+# Enhancement using public Zero-DCE model
+zero_dce_pipe = pipeline(
+    "image-enhancement",
+    model="nateraw/zero-dce",
+    device=0 if torch.cuda.is_available() else -1
+)
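+# NOTE: "image-enhancement" is not a built-in transformers pipeline task, so
+# the pipeline() call above relies on the nateraw/zero-dce repo registering a
+# custom pipeline; if it does not, this line raises at startup.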

+def enhance_image(img_pil):
     enhanced = zero_dce_pipe(img_pil)
     return enhanced[0]

+# Image Classifier (ViT)
+cls_ext = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
+cls_model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").to(device)
 cls_model.eval()

 def classify_image(img_pil):
+    inputs = cls_ext(images=img_pil, return_tensors="pt").to(device)
     with torch.no_grad():
         logits = cls_model(**inputs).logits
+    label = cls_model.config.id2label[logits.argmax(-1).item()].lower()
+    return label

+# Category-specific vibes
 CATEGORY_VIBES = {
+    "person": [...],  # same presets as before
+    "food": [...],
+    "landscape": [...],
+    "default": [...],
 }
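+# NOTE: the [...] placeholders above stand in for the concrete dict(...) presets
+# from the previous revision (keys: name, exposure_stops, contrast, saturation,
+# warmth, clarity); the app only runs once they are filled back in.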

+def apply_adjustments(img, exposure_stops, contrast, saturation, warmth, clarity):
+    img = img.convert("RGB")
+    if exposure_stops: img = ImageEnhance.Brightness(img).enhance(2**exposure_stops)
+    if contrast: img = ImageEnhance.Contrast(img).enhance(1 + contrast)
+    if saturation: img = ImageEnhance.Color(img).enhance(1 + saturation)
+    if clarity:
         arr = np.array(img).astype(np.float32)
         arr = np.clip(arr * (1 + clarity), 0, 255).astype(np.uint8)
         img = Image.fromarray(arr)
+    if warmth:
         r, g, b = img.split()
+        r = r.point(lambda i: min(255, i*(1+warmth)))
+        b = b.point(lambda i: min(255, i*(1-warmth)))
+        img = Image.merge("RGB",(r,g,b))
     return img

 def process(image):
+    enhanced = enhance_image(image)
     label = classify_image(enhanced)
+    vibes = CATEGORY_VIBES.get(label, CATEGORY_VIBES["default"])
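+    # NOTE: ViT outputs ImageNet class names ("pizza", "lakeside", ...), which
+    # rarely equal these dict keys exactly, so this lookup will usually fall
+    # back to "default"; the previous revision matched substrings instead
+    # (see the keyword-matching sketch at the end of the diff).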
+
+    best, best_score, best_name = None, -float("inf"), None
     for vibe in vibes:
+        out = apply_adjustments(enhanced, **{k: v for k, v in vibe.items() if k != "name"})
+        score = aesthetic_score(out)
+        if score > best_score:
+            best, best_score, best_name = out, score, vibe["name"]
+
+    return best, f"Classified as {label} → Chosen style: {best_name} (score {best_score:.2f})"

 demo = gr.Interface(
     fn=process,
     inputs=gr.Image(type="pil"),
+    outputs=[gr.Image(type="pil"), gr.Text()],
+    title="Content-Aware Aesthetic AI (Public)",
+    description="Enhance → classify → apply category vibes → score with public aesthetic model"
 )
 if __name__ == "__main__":
     demo.launch()
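Since the exact-match lookup in process() will usually land on the "default"
presets, here is a minimal sketch of a keyword mapping that restores the
previous revision's substring behaviour. vibes_for_label is a hypothetical
helper, not part of this commit; the keywords are the ones the old code tested:

    # Hypothetical helper (not in the commit): map an ImageNet label onto a
    # CATEGORY_VIBES key using the substrings the previous revision checked.
    CATEGORY_KEYWORDS = {
        "person": ("person",),
        "food": ("food", "dish"),
        "landscape": ("landscape", "tree", "mountain"),
    }

    def vibes_for_label(label):
        for category, words in CATEGORY_KEYWORDS.items():
            if any(word in label for word in words):
                return CATEGORY_VIBES[category]
        return CATEGORY_VIBES["default"]

    # In process(), the lookup would then read: vibes = vibes_for_label(label)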