Mavthunder commited on
Commit
8071e3e
·
verified ·
1 Parent(s): 72df98a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -49
app.py CHANGED
@@ -1,67 +1,113 @@
 
1
  import torch
2
  import torch.nn as nn
3
- import torchvision.transforms as transforms
4
- from PIL import Image
5
  import numpy as np
6
- import gradio as gr
7
 
8
- # -----------------------
9
- # Zero-DCE Model
10
- # -----------------------
11
- class ZeroDCE(nn.Module):
12
- def __init__(self, num_layers=7):
13
- super(ZeroDCE, self).__init__()
14
- filters = 32
15
- layers = []
16
- layers += [nn.Conv2d(3, filters, 3, 1, 1), nn.ReLU(inplace=True)]
17
- for _ in range(num_layers - 2):
18
- layers += [nn.Conv2d(filters, filters, 3, 1, 1), nn.ReLU(inplace=True)]
19
- layers += [nn.Conv2d(filters, 3, 3, 1, 1), nn.Tanh()]
20
- self.net = nn.Sequential(*layers)
21
 
22
- def forward(self, x):
23
- return x + self.net(x)
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- # -----------------------
26
- # Setup
27
- # -----------------------
28
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
- zero_dce = ZeroDCE().to(device)
30
 
31
- # -----------------------
32
- # Image Enhancement
33
- # -----------------------
34
- def enhance_image(image, model):
35
- # Always force RGB
36
- image = image.convert("RGB")
37
 
38
- transform = transforms.Compose([
39
- transforms.ToTensor(),
40
- ])
41
- img_tensor = transform(image).unsqueeze(0).to(device)
42
 
 
 
43
  with torch.no_grad():
44
- enhanced = model(img_tensor)
 
 
45
 
46
- # Convert back to PIL
47
- enhanced = enhanced.squeeze(0).cpu().permute(1, 2, 0).numpy()
48
- enhanced = (enhanced * 255).clip(0, 255).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- return Image.fromarray(enhanced)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- # -----------------------
53
- # Gradio Interface
54
- # -----------------------
55
- def process_image(input_img):
56
- return enhance_image(input_img, zero_dce)
 
 
 
 
 
 
 
 
 
 
57
 
 
58
  with gr.Blocks() as demo:
59
- gr.Markdown("## 🌙 Low Light Image Enhancement (Zero-DCE)")
60
  with gr.Row():
61
- inp = gr.Image(type="pil", label="Upload Dark Image")
62
- out = gr.Image(type="pil", label="Enhanced Image")
63
- run_btn = gr.Button("Enhance")
64
- run_btn.click(fn=process_image, inputs=inp, outputs=out)
 
 
65
 
66
  if __name__ == "__main__":
67
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
+ from transformers import CLIPProcessor, CLIPModel, ViTFeatureExtractor, ViTForImageClassification
5
+ from PIL import Image, ImageEnhance, ImageFilter, ImageOps
6
  import numpy as np
7
+ import cv2
8
 
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # ----------------- Aesthetic Scoring (CLIP + regression head) -----------------
12
+ class AestheticPredictor(nn.Module):
13
+ def __init__(self):
14
+ super().__init__()
15
+ self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
16
+ self.mlp = nn.Sequential(
17
+ nn.Linear(self.clip.config.projection_dim, 512),
18
+ nn.ReLU(),
19
+ nn.Linear(512, 1)
20
+ )
21
+ def forward(self, pixel_values):
22
+ outputs = self.clip(pixel_values=pixel_values).pooler_output
23
+ return self.mlp(outputs)
24
 
25
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
26
+ ae_model = AestheticPredictor().to(device)
27
+ ae_model.eval()
 
 
28
 
29
+ def aesthetic_score(img):
30
+ inputs = processor(images=img, return_tensors="pt").to(device)
31
+ with torch.no_grad():
32
+ s = ae_model(inputs['pixel_values'])
33
+ return float(s.item())
 
34
 
35
+ # ----------------- Classifier -----------------
36
+ cls_processor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
37
+ cls_model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").to(device)
38
+ cls_model.eval()
39
 
40
+ def classify(img):
41
+ inputs = cls_processor(images=img, return_tensors="pt").to(device)
42
  with torch.no_grad():
43
+ logits = cls_model(**inputs).logits
44
+ label = cls_model.config.id2label[logits.argmax(-1).item()]
45
+ return label.lower()
46
 
47
+ # ----------------- Cinematic Filters (hardcoded) -----------------
48
+ CINEMATIC_STYLES = {
49
+ "portrait": [
50
+ dict(name="Warm Glow", gamma=0.9, contrast=1.3, saturation=1.2, color_tint=(240, 200, 180), vignette=0.4),
51
+ dict(name="Moody Teal-Orange", gamma=1.0, contrast=1.4, saturation=1.0, color_tint=(30, 210, 180), vignette=0.5),
52
+ ],
53
+ "landscape": [
54
+ dict(name="Epic Teal-Orange", gamma=1.0, contrast=1.5, saturation=1.3, color_tint=(25, 200, 160), vignette=0.3),
55
+ dict(name="Soft Film Look", gamma=0.95, contrast=1.2, saturation=1.1, color_tint=(220, 200, 180), vignette=0.2),
56
+ ],
57
+ "default": [
58
+ dict(name="Classic Cinema", gamma=1.0, contrast=1.4, saturation=1.1, color_tint=(220, 180, 160), vignette=0.3),
59
+ ],
60
+ }
61
 
62
+ def apply_cinematic(img, style):
63
+ img = img.convert("RGB")
64
+ # gamma
65
+ arr = np.array(img).astype(np.float32) / 255.0
66
+ arr = arr ** style["gamma"]
67
+ # contrast & saturation
68
+ img = Image.fromarray((arr*255).astype(np.uint8))
69
+ img = ImageEnhance.Contrast(img).enhance(style["contrast"])
70
+ img = ImageEnhance.Color(img).enhance(style["saturation"])
71
+ # color tint via overlay
72
+ tint = Image.new("RGB", img.size, style["color_tint"])
73
+ img = Image.blend(img, tint, alpha=0.1)
74
+ # vignette
75
+ if style["vignette"]>0:
76
+ w, h = img.size
77
+ mask = Image.new("L", (w,h), 255)
78
+ draw = ImageDraw.Draw(mask)
79
+ draw.ellipse([(-w*style["vignette"], -h*style["vignette"]),
80
+ (w*(1+style["vignette"]), h*(1+style["vignette"]))],
81
+ fill=0)
82
+ img = Image.composite(img, ImageOps.colorize(mask, (0,0,0), (0,0,0)), mask)
83
+ return img
84
 
85
+ # ----------------- Pipeline -----------------
86
+ def process(img):
87
+ label = classify(img)
88
+ key = "portrait" if "person" in label else ("landscape" if "landscape" in label else "default")
89
+ styles = CINEMATIC_STYLES.get(key, CINEMATIC_STYLES["default"])
90
+ candidates = []
91
+ for style in styles:
92
+ out = apply_cinematic(img, style)
93
+ score = aesthetic_score(out)
94
+ candidates.append((score, style["name"], out))
95
+ candidates.sort(reverse=True, key=lambda x: x[0])
96
+ # Output winner, plus gallery
97
+ gallery = [(o, f"{name}: {s:.2f}") for s,name,o in candidates]
98
+ winner_score, winner_name, winner_img = candidates[0]
99
+ return winner_img, f"Barely cinematic vibe: **{winner_name}** (score {winner_score:.2f})", gallery
100
 
101
+ # ----------------- UI -----------------
102
  with gr.Blocks() as demo:
103
+ gr.Markdown("## Cinematic AI Instant Film-Grade Style")
104
  with gr.Row():
105
+ inp = gr.Image(type="pil", label="Upload Your Photo")
106
+ out = gr.Image(type="pil", label="Cinematic Result")
107
+ btn = gr.Button("Make Cinematic")
108
+ info = gr.Markdown()
109
+ gallery = gr.Gallery(label="All Styles (ranked)", columns=2)
110
+ btn.click(process, [inp], [out, info, gallery])
111
 
112
  if __name__ == "__main__":
113
+ demo.launch()