Mavthunder commited on
Commit
43a6e4e
·
verified ·
1 Parent(s): 7dc26b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -92
app.py CHANGED
@@ -1,113 +1,67 @@
1
- import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
- from transformers import CLIPProcessor, CLIPModel
5
  from PIL import Image
6
  import numpy as np
 
7
 
8
- # -----------------------------
9
- # 1. Zero-DCE model (light enhancement)
10
- # -----------------------------
11
  class ZeroDCE(nn.Module):
12
- def __init__(self):
13
  super(ZeroDCE, self).__init__()
14
- self.relu = nn.ReLU(inplace=True)
15
- self.conv1 = nn.Conv2d(3, 32, 3, padding=1, bias=True)
16
- self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
17
- self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
18
- self.conv4 = nn.Conv2d(32, 24, 3, padding=1, bias=True)
19
- self.conv5 = nn.Conv2d(24, 8, 3, padding=1, bias=True)
 
20
 
21
  def forward(self, x):
22
- x1 = self.relu(self.conv1(x))
23
- x2 = self.relu(self.conv2(x1))
24
- x3 = self.relu(self.conv3(x2))
25
- x4 = self.relu(self.conv4(x3))
26
- out = torch.tanh(self.conv5(x4))
27
- return out
28
-
29
- def enhance_image(img, model):
30
- # Convert PIL -> Tensor
31
- img_np = np.array(img).astype(np.float32) / 255.0
32
- img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)
33
-
34
- with torch.no_grad():
35
- enhancement_map = model(img_tensor)
36
 
37
- # Apply enhancement: simple iterative curve
38
- enhanced = img_tensor
39
- for i in range(8):
40
- enhanced = enhanced + enhancement_map[:, i*3:(i+1)*3, :, :] * (enhanced**2 - enhanced)
41
-
42
- enhanced = torch.clamp(enhanced, 0, 1)
43
- enhanced = enhanced.squeeze(0).permute(1, 2, 0).cpu().numpy()
44
- enhanced = (enhanced * 255).astype(np.uint8)
45
-
46
- return Image.fromarray(enhanced)
47
 
48
- # -----------------------------
49
- # 2. Aesthetic Scoring Model
50
- # -----------------------------
51
- class AestheticPredictor(nn.Module):
52
- def __init__(self):
53
- super().__init__()
54
- self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
55
- self.mlp = nn.Sequential(
56
- nn.Linear(self.clip.config.projection_dim, 512),
57
- nn.ReLU(),
58
- nn.Linear(512, 1)
59
- )
60
 
61
- def forward(self, pixel_values, input_ids, attention_mask):
62
- outputs = self.clip(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
63
- pooled_output = outputs.pooler_output
64
- return self.mlp(pooled_output)
65
 
66
- def score_image(image, processor, model):
67
- inputs = processor(text=["aesthetic photo"], images=image, return_tensors="pt", padding=True).to(device)
68
  with torch.no_grad():
69
- score = model(**inputs)
70
- return score.item()
71
 
72
- # -----------------------------
73
- # 3. Pipeline function
74
- # -----------------------------
75
- def process_image(input_img):
76
- # Step 1: enhance
77
- enhanced_img = enhance_image(input_img, zero_dce)
78
-
79
- # Step 2: aesthetic scoring
80
- original_score = score_image(input_img, processor, ae_model)
81
- enhanced_score = score_image(enhanced_img, processor, ae_model)
82
-
83
- # Step 3: choose best
84
- if enhanced_score > original_score:
85
- return enhanced_img, f"Enhanced chosen (score {enhanced_score:.2f} vs {original_score:.2f})"
86
- else:
87
- return input_img, f"Original kept (score {original_score:.2f} vs {enhanced_score:.2f})"
88
 
89
- # -----------------------------
90
- # 4. Setup
91
- # -----------------------------
92
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
93
 
94
- zero_dce = ZeroDCE().to(device)
95
- ae_model = AestheticPredictor().to(device)
96
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
 
 
97
 
98
- # -----------------------------
99
- # 5. Gradio UI
100
- # -----------------------------
101
  with gr.Blocks() as demo:
102
- gr.Markdown("## 📸 AI Photo Enhancer with Aesthetic Scoring")
103
-
104
  with gr.Row():
105
- inp = gr.Image(type="pil", label="Upload your photo")
106
- out = gr.Image(type="pil", label="Best looking result")
107
-
108
- info = gr.Label(label="Result Info")
109
-
110
- btn = gr.Button("Enhance ✨")
111
- btn.click(process_image, inputs=inp, outputs=[out, info])
112
 
113
- demo.launch()
 
 
 
1
  import torch
2
  import torch.nn as nn
3
+ import torchvision.transforms as transforms
4
  from PIL import Image
5
  import numpy as np
6
+ import gradio as gr
7
 
8
+ # -----------------------
9
+ # Zero-DCE Model
10
+ # -----------------------
11
  class ZeroDCE(nn.Module):
12
+ def __init__(self, num_layers=7):
13
  super(ZeroDCE, self).__init__()
14
+ filters = 32
15
+ layers = []
16
+ layers += [nn.Conv2d(3, filters, 3, 1, 1), nn.ReLU(inplace=True)]
17
+ for _ in range(num_layers - 2):
18
+ layers += [nn.Conv2d(filters, filters, 3, 1, 1), nn.ReLU(inplace=True)]
19
+ layers += [nn.Conv2d(filters, 3, 3, 1, 1), nn.Tanh()]
20
+ self.net = nn.Sequential(*layers)
21
 
22
  def forward(self, x):
23
+ return x + self.net(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ # -----------------------
26
+ # Setup
27
+ # -----------------------
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ zero_dce = ZeroDCE().to(device)
 
 
 
 
 
30
 
31
+ # -----------------------
32
+ # Image Enhancement
33
+ # -----------------------
34
+ def enhance_image(image, model):
35
+ # Always force RGB
36
+ image = image.convert("RGB")
 
 
 
 
 
 
37
 
38
+ transform = transforms.Compose([
39
+ transforms.ToTensor(),
40
+ ])
41
+ img_tensor = transform(image).unsqueeze(0).to(device)
42
 
 
 
43
  with torch.no_grad():
44
+ enhanced = model(img_tensor)
 
45
 
46
+ # Convert back to PIL
47
+ enhanced = enhanced.squeeze(0).cpu().permute(1, 2, 0).numpy()
48
+ enhanced = (enhanced * 255).clip(0, 255).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
+ return Image.fromarray(enhanced)
 
 
 
51
 
52
+ # -----------------------
53
+ # Gradio Interface
54
+ # -----------------------
55
+ def process_image(input_img):
56
+ return enhance_image(input_img, zero_dce)
57
 
 
 
 
58
  with gr.Blocks() as demo:
59
+ gr.Markdown("## 🌙 Low Light Image Enhancement (Zero-DCE)")
 
60
  with gr.Row():
61
+ inp = gr.Image(type="pil", label="Upload Dark Image")
62
+ out = gr.Image(type="pil", label="Enhanced Image")
63
+ run_btn = gr.Button("Enhance")
64
+ run_btn.click(fn=process_image, inputs=inp, outputs=out)
 
 
 
65
 
66
+ if __name__ == "__main__":
67
+ demo.launch(server_name="0.0.0.0", server_port=7860)