# Mavthunder's picture
# Update app.py
# 7dc26b7 verified
# raw
# history blame
# 3.92 kB
import gradio as gr
import torch
import torch.nn as nn
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import numpy as np
# -----------------------------
# 1. Zero-DCE model (light enhancement)
# -----------------------------
class ZeroDCE(nn.Module):
    """Small convolutional network that predicts curve-adjustment maps.

    The network takes a 3-channel image tensor and returns one 3-channel
    adjustment map per curve iteration, stacked along the channel axis
    (``3 * num_iterations`` channels in total). Values lie in ``[-1, 1]``
    because the last layer is followed by ``tanh``.

    Args:
        num_iterations: Number of curve iterations the caller will apply.
            The default of 8 yields 24 output channels, which is what a
            consumer slicing ``[i * 3:(i + 1) * 3]`` for ``i`` in
            ``range(8)`` requires.
    """

    def __init__(self, num_iterations=8):
        super().__init__()
        if num_iterations < 1:
            raise ValueError("num_iterations must be a positive integer")
        self.num_iterations = num_iterations
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
        self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
        self.conv4 = nn.Conv2d(32, 24, 3, padding=1, bias=True)
        # One RGB map per iteration. The previous 8-channel output could not
        # be split into eight 3-channel slices and failed at broadcast time.
        self.conv5 = nn.Conv2d(24, 3 * num_iterations, 3, padding=1, bias=True)

    def forward(self, x):
        """Return the stacked adjustment maps for image batch ``x``.

        Args:
            x: Tensor with 3 channels in dimension 1 (required by ``conv1``).

        Returns:
            Tensor with ``3 * num_iterations`` channels and the same spatial
            size as ``x`` (all convolutions are 3x3 with padding 1).
        """
        x1 = self.relu(self.conv1(x))
        x2 = self.relu(self.conv2(x1))
        x3 = self.relu(self.conv3(x2))
        x4 = self.relu(self.conv4(x3))
        return torch.tanh(self.conv5(x4))
def enhance_image(img, model):
    """Apply the curve maps predicted by ``model`` to a PIL image.

    Each 3-channel map ``a`` is applied in turn as
    ``x <- x + a * (x**2 - x)``, then the result is clamped to ``[0, 1]``
    and converted back to an 8-bit image.

    Args:
        img: PIL image. It is converted to RGB first, so grayscale, RGBA and
            palette images are accepted.
        model: Callable that maps a ``(1, 3, H, W)`` float tensor to a tensor
            whose channel count is a positive multiple of 3.

    Returns:
        A new RGB ``PIL.Image`` with the curves applied.

    Raises:
        ValueError: If the model output cannot be split into 3-channel maps.
    """
    # Force three channels: a 2-D (grayscale) array breaks the HWC -> CHW
    # permute, and a 4-channel (RGBA) array does not fit the model's input.
    rgb = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0
    img_tensor = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        enhancement_map = model(img_tensor)

        channels = enhancement_map.shape[1]
        if channels == 0 or channels % 3 != 0:
            # Fail with a readable message instead of a broadcast error deep
            # inside the loop.
            raise ValueError(
                "enhancement map must have a positive multiple of 3 channels, "
                f"got {channels}"
            )

        # One RGB curve per 3-channel group, applied iteratively.
        enhanced = img_tensor
        for curve in torch.split(enhancement_map, 3, dim=1):
            enhanced = enhanced + curve * (enhanced**2 - enhanced)
        enhanced = torch.clamp(enhanced, 0, 1)

    pixels = enhanced.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Round rather than truncate so values such as 254.99998 do not lose a
    # full intensity level on the way back to uint8.
    pixels = np.rint(pixels * 255).astype(np.uint8)
    return Image.fromarray(pixels)
# -----------------------------
# 2. Aesthetic Scoring Model
# -----------------------------
class AestheticPredictor(nn.Module):
    """CLIP backbone followed by a small MLP that outputs one score per image.

    NOTE(review): the MLP head is created with fresh weights here and no
    checkpoint is loaded in the visible code, so scores are only meaningful
    once trained weights are supplied — confirm where they come from.
    """

    def __init__(self):
        super().__init__()
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
        self.mlp = nn.Sequential(
            nn.Linear(self.clip.config.projection_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def forward(self, pixel_values, input_ids, attention_mask):
        """Score a batch of images.

        Args:
            pixel_values: Image tensor produced by the CLIP processor.
            input_ids: Token ids produced by the CLIP processor.
            attention_mask: Attention mask produced by the CLIP processor.

        Returns:
            Tensor of shape ``(batch, 1)`` holding the MLP output.
        """
        outputs = self.clip(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # The joint CLIP output exposes ``image_embeds`` / ``text_embeds``;
        # it has no ``pooler_output`` attribute. ``image_embeds`` has
        # ``projection_dim`` features, matching the first MLP layer.
        return self.mlp(outputs.image_embeds)
def score_image(image, processor, model):
    """Return the scalar score that ``model`` assigns to ``image``.

    The image is paired with the fixed prompt "aesthetic photo", run through
    ``processor`` to build the model inputs, moved to the module-level
    ``device`` and evaluated without gradient tracking.

    Args:
        image: Image accepted by ``processor`` (a single image, so the model
            output holds exactly one element).
        processor: Callable building the keyword inputs for ``model``.
        model: Callable returning a one-element tensor.

    Returns:
        The score as a Python float.
    """
    batch = processor(
        text=["aesthetic photo"],
        images=image,
        return_tensors="pt",
        padding=True,
    )
    batch = batch.to(device)
    with torch.no_grad():
        prediction = model(**batch)
    return prediction.item()
# -----------------------------
# 3. Pipeline function
# -----------------------------
def process_image(input_img):
    """Enhance an image and return whichever version scores higher.

    Args:
        input_img: PIL image from the UI, or ``None`` when the button is
            pressed before anything has been uploaded.

    Returns:
        A ``(image, message)`` tuple: the chosen image (or ``None`` when no
        input was given) and a human-readable summary of the decision.
    """
    # Guard: without an upload the UI passes None, which would otherwise
    # crash inside the enhancement step.
    if input_img is None:
        return None, "No image provided."

    # Step 1: enhance
    enhanced_img = enhance_image(input_img, zero_dce)

    # Step 2: aesthetic scoring
    original_score = score_image(input_img, processor, ae_model)
    enhanced_score = score_image(enhanced_img, processor, ae_model)

    # Step 3: choose best (ties keep the original)
    if enhanced_score > original_score:
        return (
            enhanced_img,
            f"Enhanced chosen (score {enhanced_score:.2f} vs {original_score:.2f})",
        )
    return (
        input_img,
        f"Original kept (score {original_score:.2f} vs {enhanced_score:.2f})",
    )
# -----------------------------
# 4. Setup
# -----------------------------
# Prefer the GPU when one is available; everything below follows `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both models are used for inference only, so switch them to eval mode
# explicitly (nn.Module.eval() returns the module itself).
# NOTE(review): no trained weights are loaded for ZeroDCE or for the MLP head
# of AestheticPredictor in the visible code — confirm this is intended.
zero_dce = ZeroDCE().to(device).eval()
ae_model = AestheticPredictor().to(device).eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
# -----------------------------
# 5. Gradio UI
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 📸 AI Photo Enhancer with Aesthetic Scoring")

    # NOTE(review): the original indentation was lost; the two image panes
    # are assumed to share the row, with the label and button below it.
    with gr.Row():
        inp = gr.Image(type="pil", label="Upload your photo")
        out = gr.Image(type="pil", label="Best looking result")

    info = gr.Label(label="Result Info")
    btn = gr.Button("Enhance ✨")

    # Wire the button: one image in, the chosen image plus a summary out.
    btn.click(
        fn=process_image,
        inputs=inp,
        outputs=[out, info],
    )

demo.launch()