import gradio as gr
import torch
import torch.nn as nn
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import numpy as np
# -----------------------------
# 1. Zero-DCE model (light enhancement)
# -----------------------------
class ZeroDCE(nn.Module):
    """Simplified Zero-DCE curve-estimation network.

    The head predicts 24 channels: 8 enhancement iterations, each
    contributing a 3-channel (RGB) curve parameter map.
    """

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
        self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
        self.conv4 = nn.Conv2d(32, 24, 3, padding=1, bias=True)
        # 24 = 8 iterations x 3 RGB channels, matching the slicing in enhance_image
        self.conv5 = nn.Conv2d(24, 24, 3, padding=1, bias=True)

    def forward(self, x):
        x1 = self.relu(self.conv1(x))
        x2 = self.relu(self.conv2(x1))
        x3 = self.relu(self.conv3(x2))
        x4 = self.relu(self.conv4(x3))
        # tanh bounds the curve parameters to [-1, 1]
        return torch.tanh(self.conv5(x4))
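# Quick sanity check for the curve head (a sketch, not part of the app flow;
# the 256x256 size is an arbitrary example):
#
#   maps = ZeroDCE()(torch.rand(1, 3, 256, 256))
#   assert maps.shape == (1, 24, 256, 256)  # 8 iterations x 3 RGB channels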
def enhance_image(img, model):
    # Convert PIL -> normalized (1, 3, H, W) float tensor
    img_np = np.array(img.convert("RGB")).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        enhancement_map = model(img_tensor)
    # Apply the iterative light-enhancement curve:
    # LE_n(x) = LE_{n-1}(x) + alpha_n * (LE_{n-1}(x)^2 - LE_{n-1}(x))
    enhanced = img_tensor
    for i in range(8):
        alpha = enhancement_map[:, i * 3:(i + 1) * 3, :, :]
        enhanced = enhanced + alpha * (enhanced ** 2 - enhanced)
    enhanced = torch.clamp(enhanced, 0, 1)
    enhanced = enhanced.squeeze(0).permute(1, 2, 0).cpu().numpy()
    return Image.fromarray((enhanced * 255).astype(np.uint8))
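# Example usage (a sketch; "dark.jpg" is a hypothetical input file):
#
#   bright = enhance_image(Image.open("dark.jpg"), zero_dce)
#   bright.save("bright.jpg")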
# -----------------------------
# 2. Aesthetic Scoring Model
# -----------------------------
class AestheticPredictor(nn.Module):
    """CLIP image embedding followed by a small scoring MLP head."""

    def __init__(self):
        super().__init__()
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
        self.mlp = nn.Sequential(
            nn.Linear(self.clip.config.projection_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def forward(self, pixel_values, input_ids, attention_mask):
        outputs = self.clip(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
        # CLIPModel has no pooler_output; score the projected image embedding
        return self.mlp(outputs.image_embeds)
def score_image(image, processor, model):
    inputs = processor(text=["aesthetic photo"], images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        score = model(**inputs)
    return score.item()
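# Usage sketch ("photo.jpg" is a hypothetical file; a higher output means the
# MLP head rates the image as more aesthetic):
#
#   s = score_image(Image.open("photo.jpg"), processor, ae_model)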
# -----------------------------
# 3. Pipeline function
# -----------------------------
def process_image(input_img):
    # Step 1: enhance
    enhanced_img = enhance_image(input_img, zero_dce)
    # Step 2: aesthetic scoring
    original_score = score_image(input_img, processor, ae_model)
    enhanced_score = score_image(enhanced_img, processor, ae_model)
    # Step 3: choose best
    if enhanced_score > original_score:
        return enhanced_img, f"Enhanced chosen (score {enhanced_score:.2f} vs {original_score:.2f})"
    else:
        return input_img, f"Original kept (score {original_score:.2f} vs {enhanced_score:.2f})"
# -----------------------------
# 4. Setup
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
zero_dce = ZeroDCE().to(device).eval()
ae_model = AestheticPredictor().to(device).eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
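# Note: ZeroDCE and the scoring MLP are randomly initialised here; for
# meaningful results, load pretrained checkpoints. A sketch, assuming
# hypothetical local weight files:
#
#   zero_dce.load_state_dict(torch.load("zero_dce.pth", map_location=device))
#   ae_model.mlp.load_state_dict(torch.load("aesthetic_mlp.pth", map_location=device))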
# -----------------------------
# 5. Gradio UI
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 📸 AI Photo Enhancer with Aesthetic Scoring")
    with gr.Row():
        inp = gr.Image(type="pil", label="Upload your photo")
        out = gr.Image(type="pil", label="Best looking result")
    info = gr.Label(label="Result Info")
    btn = gr.Button("Enhance ✨")
    btn.click(process_image, inputs=inp, outputs=[out, info])

demo.launch()