import gradio as gr
import torch
import torch.nn as nn
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import numpy as np

# -----------------------------
# 1. Zero-DCE model (light enhancement)
# -----------------------------
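# Zero-DCE ("Zero-Reference Deep Curve Estimation") predicts per-pixel curve
# parameters and brightens an image by applying those curves iteratively.
# The network below is a simplified, randomly initialized variant of that idea.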
class ZeroDCE(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
        self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
        self.conv4 = nn.Conv2d(32, 24, 3, padding=1, bias=True)
        # 24 output channels = 8 iterations x 3 (RGB) curve-parameter maps,
        # matching the i*3:(i+1)*3 slicing in enhance_image below
        self.conv5 = nn.Conv2d(24, 24, 3, padding=1, bias=True)

    def forward(self, x):
        x1 = self.relu(self.conv1(x))
        x2 = self.relu(self.conv2(x1))
        x3 = self.relu(self.conv3(x2))
        x4 = self.relu(self.conv4(x3))
        out = torch.tanh(self.conv5(x4))  # curve parameters (alpha maps) in [-1, 1]
        return out

def enhance_image(img, model):
    # Convert PIL -> tensor in [0, 1], shape (1, 3, H, W); force RGB in case
    # the upload is grayscale or carries an alpha channel
    img_np = np.array(img.convert("RGB")).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        enhancement_map = model(img_tensor)

    # Apply the light-enhancement curve iteratively, one 3-channel alpha map
    # per step: LE(x) = x + alpha * (x^2 - x)
    enhanced = img_tensor
    for i in range(8):
        enhanced = enhanced + enhancement_map[:, i*3:(i+1)*3, :, :] * (enhanced**2 - enhanced)

    enhanced = torch.clamp(enhanced, 0, 1)
    enhanced = enhanced.squeeze(0).permute(1, 2, 0).cpu().numpy()
    enhanced = (enhanced * 255).astype(np.uint8)

    return Image.fromarray(enhanced)
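
# Example usage (hypothetical file names), once `zero_dce` and `device` from
# the setup section below exist:
#   img = Image.open("photo.jpg").convert("RGB")
#   enhance_image(img, zero_dce).save("photo_enhanced.jpg")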

# -----------------------------
# 2. Aesthetic Scoring Model
# -----------------------------
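# A CLIP image encoder with a small MLP regression head, in the spirit of
# CLIP-based aesthetic predictors. The head below is randomly initialized,
# so its scores only become meaningful once trained weights are loaded.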
class AestheticPredictor(nn.Module):
    def __init__(self):
        super().__init__()
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
        self.mlp = nn.Sequential(
            nn.Linear(self.clip.config.projection_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

    def forward(self, pixel_values, input_ids, attention_mask):
        outputs = self.clip(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
        # CLIPModel's output has no pooler_output; use the projected image
        # embedding, whose size matches self.clip.config.projection_dim
        image_embeds = outputs.image_embeds
        return self.mlp(image_embeds)

def score_image(image, processor, model):
    # The text prompt only satisfies CLIP's forward signature; the score
    # itself is computed from the image embedding alone
    inputs = processor(text=["aesthetic photo"], images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        score = model(**inputs)
    return score.item()
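
# Example (hypothetical file names): comparing two photos by predicted score,
# assuming `processor`, `ae_model`, and `device` from the setup section exist.
#   a = score_image(Image.open("a.jpg"), processor, ae_model)
#   b = score_image(Image.open("b.jpg"), processor, ae_model)
#   print("a.jpg wins" if a > b else "b.jpg wins")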

# -----------------------------
# 3. Pipeline function
# -----------------------------
def process_image(input_img):
    # Step 1: enhance
    enhanced_img = enhance_image(input_img, zero_dce)

    # Step 2: aesthetic scoring
    original_score = score_image(input_img, processor, ae_model)
    enhanced_score = score_image(enhanced_img, processor, ae_model)

    # Step 3: choose best
    if enhanced_score > original_score:
        return enhanced_img, f"Enhanced chosen (score {enhanced_score:.2f} vs {original_score:.2f})"
    else:
        return input_img, f"Original kept (score {original_score:.2f} vs {enhanced_score:.2f})"

# -----------------------------
# 4. Setup
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# .eval() switches the models to inference mode
zero_dce = ZeroDCE().to(device).eval()
ae_model = AestheticPredictor().to(device).eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
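
# NOTE: both models start from random weights in this sketch. For meaningful
# results you would load trained checkpoints here, e.g. (hypothetical paths):
#   zero_dce.load_state_dict(torch.load("zero_dce.pth", map_location=device))
#   ae_model.mlp.load_state_dict(torch.load("aesthetic_mlp.pth", map_location=device))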

# -----------------------------
# 5. Gradio UI
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 📸 AI Photo Enhancer with Aesthetic Scoring")

    with gr.Row():
        inp = gr.Image(type="pil", label="Upload your photo")
        out = gr.Image(type="pil", label="Best looking result")
    
    info = gr.Label(label="Result Info")

    btn = gr.Button("Enhance ✨")
    btn.click(process_image, inputs=inp, outputs=[out, info])

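# launch() serves the app locally; demo.launch(share=True) would additionally
# create a temporary public link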
demo.launch()