"""Gradio demo: enhance a photo with a Zero-DCE-style network and keep the
version that a CLIP-based aesthetic scorer rates higher.

NOTE: no checkpoint is loaded in this file for either ``ZeroDCE`` or the MLP
head of ``AestheticPredictor``, so both run with randomly initialised weights
unless trained weights are loaded elsewhere.
"""

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Number of curve iterations predicted by ZeroDCE (one RGB map per iteration).
NUM_CURVE_ITERATIONS = 8


# -----------------------------
# 1. Zero-DCE model (light enhancement)
# -----------------------------
class ZeroDCE(nn.Module):
    """Small CNN that predicts per-pixel light-enhancement curve parameters.

    The output is *not* an image: it has ``3 * NUM_CURVE_ITERATIONS`` (= 24)
    channels squashed into [-1, 1] by ``tanh``. Use ``enhance_image`` to turn
    these maps into an enhanced picture.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 32, 3, padding=1)
        self.conv4 = nn.Conv2d(32, 3 * NUM_CURVE_ITERATIONS, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return curve-parameter maps of shape (N, 24, H, W) for input (N, 3, H, W)."""
        x1 = self.relu(self.conv1(x))
        x2 = self.relu(self.conv2(x1))
        x3 = self.relu(self.conv3(x2))
        x_r = torch.tanh(self.conv4(x3))
        return x_r


def _apply_curves(image, curve_maps):
    """Apply the quadratic enhancement curve once per 3-channel parameter map.

    Each step computes ``x + a * x * (1 - x)``; with ``x`` in [0, 1] and ``a``
    in [-1, 1] the result stays inside [0, 1].
    """
    enhanced = image
    for alpha in torch.split(curve_maps, 3, dim=1):
        enhanced = enhanced + alpha * enhanced * (1.0 - enhanced)
    return enhanced


def enhance_image(img, model):
    """Run ``model`` on a PIL image and return the enhanced PIL image (RGB).

    Args:
        img: input ``PIL.Image``; converted to RGB so grayscale/RGBA inputs
            match the 3-channel first convolution.
        model: a ``ZeroDCE`` instance (or compatible module returning
            24-channel curve maps).

    Returns:
        A new RGB ``PIL.Image`` with the same size as ``img``.
    """
    # Use the device the model lives on rather than relying on a global.
    model_device = next(model.parameters()).device

    pixels = np.array(img.convert("RGB"), dtype=np.float32) / 255.0
    img_tensor = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).to(model_device)

    with torch.no_grad():
        curve_maps = model(img_tensor)
        # The network output is curve parameters; the enhanced image is the
        # input pushed through those curves.
        enhanced = _apply_curves(img_tensor, curve_maps)

    enhanced = enhanced.squeeze(0).permute(1, 2, 0).cpu().numpy()
    enhanced = np.clip(enhanced * 255.0, 0, 255).round().astype(np.uint8)
    return Image.fromarray(enhanced)


# -----------------------------
# 2. Aesthetic Scoring Model
# -----------------------------
class AestheticPredictor(nn.Module):
    """CLIP backbone followed by a two-layer MLP that emits one scalar score."""

    def __init__(self):
        super().__init__()
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
        self.mlp = nn.Sequential(
            nn.Linear(self.clip.config.projection_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def forward(self, pixel_values, input_ids, attention_mask):
        """Return a (batch, 1) score tensor computed from the CLIP image embedding."""
        outputs = self.clip(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # CLIPModel's output has no top-level `pooler_output`; `image_embeds`
        # is the projected image embedding whose width is `projection_dim`,
        # which is what the MLP's first layer expects.
        return self.mlp(outputs.image_embeds)


def score_image(image, processor, model):
    """Return the scalar aesthetic score of ``image`` as a Python float.

    Args:
        image: ``PIL.Image`` to score.
        processor: ``CLIPProcessor`` matching the model's CLIP checkpoint.
        model: an ``AestheticPredictor`` instance.
    """
    model_device = next(model.parameters()).device
    inputs = processor(
        text=["aesthetic photo"],
        images=image,
        return_tensors="pt",
        padding=True,
    ).to(model_device)
    with torch.no_grad():
        score = model(**inputs)
    return score.item()


# -----------------------------
# 3. Pipeline function
# -----------------------------
def process_image(input_img):
    """Enhance ``input_img``, score both versions, and return the better one.

    Returns:
        ``(image, message)`` where ``image`` is the chosen ``PIL.Image`` (or
        ``None`` when nothing was uploaded) and ``message`` describes the
        decision.
    """
    # The button can be clicked before any image is uploaded.
    if input_img is None:
        return None, "No image provided"

    # Step 1: enhance
    enhanced_img = enhance_image(input_img, zero_dce)

    # Step 2: aesthetic scoring
    original_score = score_image(input_img, processor, ae_model)
    enhanced_score = score_image(enhanced_img, processor, ae_model)

    # Step 3: choose best
    if enhanced_score > original_score:
        return enhanced_img, f"Enhanced chosen (score {enhanced_score:.2f} vs {original_score:.2f})"
    else:
        return input_img, f"Original kept (score {original_score:.2f} vs {enhanced_score:.2f})"


# -----------------------------
# 4. Setup
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
zero_dce = ZeroDCE().to(device).eval()
ae_model = AestheticPredictor().to(device).eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

# -----------------------------
# 5. Gradio UI
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 📸 AI Photo Enhancer with Aesthetic Scoring")
    with gr.Row():
        inp = gr.Image(type="pil", label="Upload your photo")
        out = gr.Image(type="pil", label="Best looking result")
    info = gr.Label(label="Result Info")
    btn = gr.Button("Enhance ✨")
    btn.click(process_image, inputs=inp, outputs=[out, info])

# Only start the server when run as a script, so importing this module
# (e.g. for reuse of the helpers) does not launch the UI.
if __name__ == "__main__":
    demo.launch()