Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from transformers import CLIPProcessor, CLIPModel | |
| from PIL import Image | |
| import numpy as np | |
| # ----------------------------- | |
| # 1. Zero-DCE model (light enhancement) | |
| # ----------------------------- | |
class ZeroDCE(nn.Module):
    """Small fully-convolutional net that predicts per-pixel enhancement curves.

    For an input of shape ``(N, 3, H, W)`` the forward pass returns a map of
    shape ``(N, 3 * n_iter, H, W)`` with values in ``[-1, 1]`` (``tanh``).
    Each consecutive group of 3 channels holds the per-pixel, per-RGB-channel
    curve parameter for one iteration of the quadratic curve applied in
    ``enhance_image``.

    NOTE: this is a simplified Zero-DCE variant (plain conv stack, no skip
    connections). Nothing in this class loads weights, so outputs are random
    unless the caller loads a trained checkpoint.

    Args:
        n_iter: number of curve iterations the output map parameterises.
            The default of 8 matches the curve step in this file.
    """

    def __init__(self, n_iter=8):
        super().__init__()
        self.n_iter = n_iter
        self.relu = nn.ReLU(inplace=True)
        # 3x3 convs with padding=1 keep the spatial size unchanged.
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
        self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
        self.conv4 = nn.Conv2d(32, 24, 3, padding=1, bias=True)
        # One curve parameter per RGB channel per iteration. The previous fixed
        # 8 output channels could not supply 8 iterations of 3-channel slices,
        # so the curve step crashed on a shape mismatch.
        self.conv5 = nn.Conv2d(24, 3 * n_iter, 3, padding=1, bias=True)

    def forward(self, x):
        """Return the curve-parameter map (``3 * n_iter`` channels, in [-1, 1])."""
        x1 = self.relu(self.conv1(x))
        x2 = self.relu(self.conv2(x1))
        x3 = self.relu(self.conv3(x2))
        x4 = self.relu(self.conv4(x3))
        return torch.tanh(self.conv5(x4))
def enhance_image(img, model):
    """Brighten a PIL image with the curves predicted by a Zero-DCE style model.

    Args:
        img: PIL image. It is converted to RGB first, so grayscale ("L") and
            RGBA uploads work instead of failing on ``permute(2, 0, 1)`` or on
            the model's 3-channel first convolution.
        model: module mapping a ``(1, 3, H, W)`` float tensor in ``[0, 1]`` to
            a curve map with ``3 * k`` channels. ``k`` iterations are applied.

    Returns:
        A new RGB PIL image with the same size as ``img``.
    """
    # Run on whatever device the model lives on, not a module-level global.
    model_device = next(model.parameters()).device

    # Convert PIL -> float tensor in [0, 1], laid out as (1, 3, H, W).
    rgb = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0
    x = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).to(model_device)

    with torch.no_grad():
        curve_map = model(x)
        # Each consecutive group of 3 channels parameterises one iteration.
        # Deriving the count (instead of hard-coding 8) keeps the slices in
        # range for any model output width.
        n_iter = curve_map.shape[1] // 3
        enhanced = x
        for i in range(n_iter):
            alpha = curve_map[:, 3 * i:3 * (i + 1), :, :]
            # Iterative quadratic curve: x <- x + alpha * (x^2 - x)
            enhanced = enhanced + alpha * (enhanced ** 2 - enhanced)
        enhanced = torch.clamp(enhanced, 0, 1)

    out = enhanced.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Round instead of truncating, which biased every pixel downward.
    return Image.fromarray(np.rint(out * 255).astype(np.uint8))
| # ----------------------------- | |
| # 2. Aesthetic Scoring Model | |
| # ----------------------------- | |
class AestheticPredictor(nn.Module):
    """CLIP ViT-B/16 image embedding followed by a small MLP scoring head.

    NOTE: only the CLIP backbone is loaded from pretrained weights. The MLP
    head is freshly initialised here, so its scores carry no aesthetic meaning
    until trained head weights are loaded.
    """

    def __init__(self):
        super().__init__()
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
        # Input width is projection_dim: the size of CLIP's projected
        # image embedding (``image_embeds``).
        self.mlp = nn.Sequential(
            nn.Linear(self.clip.config.projection_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def forward(self, pixel_values, input_ids, attention_mask):
        """Return a ``(batch, 1)`` score tensor for the images in ``pixel_values``.

        ``input_ids`` / ``attention_mask`` are passed through because
        ``CLIPModel``'s joint forward requires text inputs, but the score
        depends only on the image embedding.
        """
        outputs = self.clip(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # CLIPModel's output object has no top-level ``pooler_output``, so the
        # old access raised AttributeError. ``pooler_output`` lives on the
        # nested vision/text outputs and has hidden_size features, not
        # projection_dim. ``image_embeds`` is the projected image embedding,
        # whose width matches the head.
        return self.mlp(outputs.image_embeds)
def score_image(image, processor, model):
    """Return the scalar score that ``model`` assigns to a single image.

    Args:
        image: image accepted by ``processor`` (a PIL image in this app).
        processor: CLIP processor producing the keyword inputs ``model``
            expects (``pixel_values``, ``input_ids``, ``attention_mask``).
        model: scoring module whose output has exactly one element.

    Returns:
        The score as a Python float.
    """
    # Move inputs to the model's own device instead of relying on a global.
    target_device = next(model.parameters()).device
    inputs = processor(
        text=["aesthetic photo"], images=image, return_tensors="pt", padding=True
    ).to(target_device)
    with torch.no_grad():
        score = model(**inputs)
    return score.item()
| # ----------------------------- | |
| # 3. Pipeline function | |
| # ----------------------------- | |
def process_image(input_img):
    """Enhance an uploaded photo and keep whichever version scores higher.

    Args:
        input_img: the uploaded PIL image, or None when nothing was uploaded
            (e.g. the button was clicked before choosing a photo).

    Returns:
        An ``(image, message)`` pair for the Gradio outputs: the chosen image
        and a short text naming the winner and both scores. Ties keep the
        original. With no input, returns ``(None, message)``.
    """
    # Without this guard a click with no upload crashed inside enhance_image.
    if input_img is None:
        return None, "Please upload a photo first."

    # Step 1: enhance
    enhanced_img = enhance_image(input_img, zero_dce)

    # Step 2: aesthetic scoring of both versions
    original_score = score_image(input_img, processor, ae_model)
    enhanced_score = score_image(enhanced_img, processor, ae_model)

    # Step 3: choose best (the enhanced image must score strictly higher)
    if enhanced_score > original_score:
        return enhanced_img, f"Enhanced chosen (score {enhanced_score:.2f} vs {original_score:.2f})"
    return input_img, f"Original kept (score {original_score:.2f} vs {enhanced_score:.2f})"
| # ----------------------------- | |
| # 4. Setup | |
| # ----------------------------- | |
# Use the GPU when one is available; both models are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Both models are only used for inference, so switch them to eval mode.
# NOTE: no checkpoint is loaded for ZeroDCE (or the aesthetic MLP head) here,
# so their outputs are random unless weights are loaded elsewhere.
zero_dce = ZeroDCE().to(device).eval()
ae_model = AestheticPredictor().to(device).eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
| # ----------------------------- | |
| # 5. Gradio UI | |
| # ----------------------------- | |
with gr.Blocks() as demo:
    gr.Markdown("## 📸 AI Photo Enhancer with Aesthetic Scoring")

    # Uploaded photo and chosen result, shown side by side.
    with gr.Row():
        inp = gr.Image(label="Upload your photo", type="pil")
        out = gr.Image(label="Best looking result", type="pil")

    # Text describing which version was kept and both scores.
    info = gr.Label(label="Result Info")
    btn = gr.Button("Enhance ✨")

    # process_image takes the uploaded image and returns (image, message),
    # which fill the result image and the info label respectively.
    btn.click(fn=process_image, inputs=[inp], outputs=[out, info])

demo.launch()