Mavthunder committed on
Commit
cf39539
·
verified ·
1 Parent(s): 5fe3046

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -87
app.py CHANGED
@@ -2,124 +2,90 @@ import gradio as gr
2
  from PIL import Image, ImageEnhance
3
  import numpy as np
4
  import torch
5
- from transformers import AutoProcessor, AutoModel, pipeline, ViTFeatureExtractor, ViTForImageClassification
 
6
 
7
- # ------------------ Device ------------------
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
- # ------------------ Aesthetic Predictor ------------------
11
- predictor_name = "shunk031/aesthetic-predictor-v2"
12
- processor = AutoProcessor.from_pretrained(predictor_name)
13
- aesthetic_model = AutoModel.from_pretrained(predictor_name).to(device)
14
- aesthetic_model.eval()
15
 
16
- def aesthetic_score_ai(img_pil):
17
- inputs = processor(images=img_pil, return_tensors="pt").to(device)
18
  with torch.no_grad():
19
- outputs = aesthetic_model(**inputs)
20
- return float(outputs.logits.mean().item())
21
-
22
- # ------------------ Zero-DCE (public) ------------------
23
- zero_dce_pipe = pipeline("image-enhancement", model="nateraw/zero-dce", device=0 if torch.cuda.is_available() else -1)
 
 
 
 
 
24
 
25
- def zero_dce_enhance(img_pil):
26
  enhanced = zero_dce_pipe(img_pil)
27
  return enhanced[0]
28
 
29
- # ------------------ Image Classifier ------------------
30
- cls_model_name = "google/vit-base-patch16-224"
31
- cls_extractor = ViTFeatureExtractor.from_pretrained(cls_model_name)
32
- cls_model = ViTForImageClassification.from_pretrained(cls_model_name).to(device)
33
  cls_model.eval()
34
 
35
  def classify_image(img_pil):
36
- inputs = cls_extractor(images=img_pil, return_tensors="pt").to(device)
37
  with torch.no_grad():
38
  logits = cls_model(**inputs).logits
39
- pred = logits.argmax(-1).item()
40
- label = cls_model.config.id2label[pred]
41
- return label.lower()
42
 
43
- # ------------------ Vibes per Category ------------------
44
  CATEGORY_VIBES = {
45
- "person": [
46
- dict(name="Portrait Soft", exposure_stops=0.1, contrast=0.05, saturation=0.05, warmth=0.1, clarity=-0.05),
47
- dict(name="Portrait Glow", exposure_stops=0.15, contrast=-0.02, saturation=0.1, warmth=0.08, clarity=0.0),
48
- ],
49
- "food": [
50
- dict(name="Food Vibrant", exposure_stops=0.1, contrast=0.15, saturation=0.25, warmth=0.05, clarity=0.1),
51
- dict(name="Food Natural", exposure_stops=0.05, contrast=0.05, saturation=0.1, warmth=0.02, clarity=0.05),
52
- ],
53
- "landscape": [
54
- dict(name="Landscape Punch", exposure_stops=0.1, contrast=0.2, saturation=0.2, warmth=0.05, clarity=0.15),
55
- dict(name="Landscape Film", exposure_stops=0.0, contrast=0.1, saturation=0.05, warmth=0.02, clarity=0.1),
56
- ],
57
- "default": [
58
- dict(name="Pop", exposure_stops=0.05, contrast=0.2, saturation=0.2, warmth=0.05, clarity=0.1),
59
- dict(name="Moody", exposure_stops=-0.05, contrast=0.15, saturation=-0.05, warmth=-0.05, clarity=0.2),
60
- ],
61
  }
62
 
63
- # ------------------ Image Adjustment ------------------
64
- def apply_adjustments(img_pil, exposure_stops=0.0, contrast=0.0, saturation=0.0, warmth=0.0, clarity=0.0):
65
- img = img_pil.convert("RGB")
66
- if exposure_stops != 0:
67
- factor = 2.0 ** exposure_stops
68
- img = ImageEnhance.Brightness(img).enhance(factor)
69
- if contrast != 0:
70
- img = ImageEnhance.Contrast(img).enhance(1 + contrast)
71
- if saturation != 0:
72
- img = ImageEnhance.Color(img).enhance(1 + saturation)
73
- if clarity != 0:
74
  arr = np.array(img).astype(np.float32)
75
  arr = np.clip(arr * (1 + clarity), 0, 255).astype(np.uint8)
76
  img = Image.fromarray(arr)
77
- if warmth != 0:
78
  r, g, b = img.split()
79
- r = r.point(lambda i: min(255, i * (1 + warmth)))
80
- b = b.point(lambda i: min(255, i * (1 - warmth)))
81
- img = Image.merge("RGB", (r, g, b))
82
  return img
83
 
84
- # ------------------ Main Process ------------------
85
  def process(image):
86
- # Step 1: Enhance with Zero-DCE
87
- enhanced = zero_dce_enhance(image)
88
-
89
- # Step 2: Classify image
90
  label = classify_image(enhanced)
91
-
92
- # Pick vibes for category
93
- if "person" in label or "face" in label:
94
- vibes = CATEGORY_VIBES["person"]
95
- elif "food" in label or "dish" in label:
96
- vibes = CATEGORY_VIBES["food"]
97
- elif "landscape" in label or "tree" in label or "mountain" in label:
98
- vibes = CATEGORY_VIBES["landscape"]
99
- else:
100
- vibes = CATEGORY_VIBES["default"]
101
-
102
- # Step 3: Apply vibes + score
103
- candidates = []
104
  for vibe in vibes:
105
- out = apply_adjustments(enhanced, **{k: v for k, v in vibe.items() if k != "name"})
106
- score = aesthetic_score_ai(out)
107
- candidates.append((score, vibe["name"], out))
108
-
109
- # Step 4: Pick best
110
- best = max(candidates, key=lambda x: x[0])
111
- score, vibe_name, img_out = best
112
 
113
- return img_out, f"Classified as: {label} → Chosen: {vibe_name} (score {score:.2f})"
114
-
115
- # ------------------ UI ------------------
116
  demo = gr.Interface(
117
  fn=process,
118
  inputs=gr.Image(type="pil"),
119
- outputs=[gr.Image(type="pil"), gr.Textbox()],
120
- title="AI Aesthetic Photo Enhancer",
121
- description="Uploads → Enhance (Zero-DCE) → Classify → Apply vibes → Pick most aesthetic"
122
  )
123
-
124
  if __name__ == "__main__":
125
  demo.launch()
 
2
  from PIL import Image, ImageEnhance
3
  import numpy as np
4
  import torch
5
+ from transformers import AutoProcessor, AutoModel, pipeline, ViTFeatureExtractor, ViTForImageClassification, CLIPProcessor
6
+ import cv2
7
 
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
+ # Aesthetic Scorer: rsinema/aesthetic-scorer (public)
11
+ ae_processor = CLIPProcessor.from_pretrained("rsinema/aesthetic-scorer")
12
+ ae_model = AutoModel.from_pretrained("rsinema/aesthetic-scorer").to(device)
13
+ ae_model.eval()
 
14
 
15
+ def aesthetic_score(img_pil):
16
+ inputs = ae_processor(images=img_pil, return_tensors="pt")["pixel_values"].to(device)
17
  with torch.no_grad():
18
+ scores = ae_model(inputs)
19
+ # scores returns 7 dims; first is overall aesthetic
20
+ return float(scores[0][0].item())
21
+
22
+ # Enhancement using public Zero-DCE model
23
+ zero_dce_pipe = pipeline(
24
+ "image-enhancement",
25
+ model="nateraw/zero-dce",
26
+ device=0 if torch.cuda.is_available() else -1
27
+ )
28
 
29
+ def enhance_image(img_pil):
30
  enhanced = zero_dce_pipe(img_pil)
31
  return enhanced[0]
32
 
33
+ # Image Classifier (ViT)
34
+ cls_ext = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
35
+ cls_model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").to(device)
 
36
  cls_model.eval()
37
 
38
  def classify_image(img_pil):
39
+ inputs = cls_ext(images=img_pil, return_tensors="pt").to(device)
40
  with torch.no_grad():
41
  logits = cls_model(**inputs).logits
42
+ label = cls_model.config.id2label[logits.argmax(-1).item()].lower()
43
+ return label
 
44
 
45
+ # Category-specific vibes
46
  CATEGORY_VIBES = {
47
+ "person": [...], # same presets as before
48
+ "food": [...],
49
+ "landscape": [...],
50
+ "default": [...],
 
 
 
 
 
 
 
 
 
 
 
 
51
  }
52
 
53
+ def apply_adjustments(img, exposure, contrast, saturation, warmth, clarity):
54
+ img = img.convert("RGB")
55
+ if exposure: img = ImageEnhance.Brightness(img).enhance(2**exposure)
56
+ if contrast: img = ImageEnhance.Contrast(img).enhance(1 + contrast)
57
+ if saturation: img = ImageEnhance.Color(img).enhance(1 + saturation)
58
+ if clarity:
 
 
 
 
 
59
  arr = np.array(img).astype(np.float32)
60
  arr = np.clip(arr * (1 + clarity), 0, 255).astype(np.uint8)
61
  img = Image.fromarray(arr)
62
+ if warmth:
63
  r, g, b = img.split()
64
+ r = r.point(lambda i: min(255, i*(1+warmth)))
65
+ b = b.point(lambda i: min(255, i*(1-warmth)))
66
+ img = Image.merge("RGB",(r,g,b))
67
  return img
68
 
 
69
  def process(image):
70
+ enhanced = enhance_image(image)
 
 
 
71
  label = classify_image(enhanced)
72
+ vibes = CATEGORY_VIBES.get(label, CATEGORY_VIBES["default"])
73
+
74
+ best, best_score, best_name = None, -float("inf"), None
 
 
 
 
 
 
 
 
 
 
75
  for vibe in vibes:
76
+ out = apply_adjustments(enhanced, **vibe)
77
+ score = aesthetic_score(out)
78
+ if score > best_score:
79
+ best, best_score, best_name = out, score, vibe["name"]
80
+
81
+ return best, f"Classified as {label} → Chosen style: {best_name} (score {best_score:.2f})"
 
82
 
 
 
 
83
  demo = gr.Interface(
84
  fn=process,
85
  inputs=gr.Image(type="pil"),
86
+ outputs=[gr.Image(type="pil"), gr.Text()],
87
+ title="Content-Aware Aesthetic AI (Public)",
88
+ description="Enhance → classify → apply category vibes → score with public aesthetic model"
89
  )
 
90
  if __name__ == "__main__":
91
  demo.launch()