# app.py — commit 43a6e4e ("Update app.py", verified) by Mavthunder; 1.99 kB.
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import gradio as gr
# -----------------------
# Zero-DCE Model
# -----------------------
class ZeroDCE(nn.Module):
    """Curve-estimation network in the style of Zero-DCE / Zero-DCE++.

    A plain stack of 3x3 convolutions predicts a per-pixel, per-channel curve
    parameter map ``alpha`` in ``[-1, 1]`` (via ``tanh``).  The input is then
    enhanced by repeatedly applying the quadratic light-enhancement curve

        LE(x) = x + alpha * x * (1 - x)

    For ``x`` in ``[0, 1]`` and ``alpha`` in ``[-1, 1]`` this curve maps back
    into ``[0, 1]`` and fixes pure black (0) and pure white (1), so the output
    remains a valid image.  The same ``alpha`` map is reused on every
    iteration.

    Args:
        num_layers: Total number of convolution layers (input + hidden +
            output).  Values below 2 still build the input and output layers.
        num_iterations: How many times the curve is applied in ``forward``.
            8 matches the iteration count used by Zero-DCE.
    """

    def __init__(self, num_layers=7, num_iterations=8):
        super().__init__()
        filters = 32
        self.num_iterations = num_iterations
        # Layer layout is unchanged, so ``self.net`` state_dict keys/shapes stay
        # compatible with any checkpoint saved from the earlier version.
        layers = [nn.Conv2d(3, filters, 3, 1, 1), nn.ReLU(inplace=True)]
        for _ in range(num_layers - 2):
            layers += [nn.Conv2d(filters, filters, 3, 1, 1), nn.ReLU(inplace=True)]
        # tanh bounds the curve parameters to [-1, 1], which keeps LE(x) in [0, 1].
        layers += [nn.Conv2d(filters, 3, 3, 1, 1), nn.Tanh()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Enhance ``x``.

        Assumes ``x`` is an ``(N, 3, H, W)`` float tensor with values in
        ``[0, 1]`` (as produced by ``transforms.ToTensor``).  Returns a tensor
        of the same shape.
        """
        alpha = self.net(x)
        enhanced = x
        for _ in range(self.num_iterations):
            enhanced = enhanced + alpha * enhanced * (1 - enhanced)
        return enhanced
# -----------------------
# Setup
# -----------------------
# Prefer the GPU when one is visible; otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): no checkpoint is loaded anywhere in this file, so the network
# runs with random initial weights and its output is not a trained
# enhancement.  TODO: load pretrained weights (e.g. via load_state_dict with
# map_location=device) before serving.
zero_dce = ZeroDCE().to(device)
# Inference-only model: put it in eval mode so layers such as BatchNorm or
# Dropout (if ever added) behave deterministically.
zero_dce.eval()
# -----------------------
# Image Enhancement
# -----------------------
def enhance_image(image, model):
    """Run ``model`` on a PIL image and return the result as a PIL image.

    Args:
        image: A PIL image in any mode; it is converted to RGB first.
        model: A callable (normally an ``nn.Module``) mapping a
            ``(1, 3, H, W)`` float tensor in ``[0, 1]`` to a tensor of the
            same shape.

    Returns:
        An RGB ``PIL.Image`` with 8-bit pixels.  Model outputs outside
        ``[0, 1]`` are clamped.
    """
    # Always force RGB (drops alpha, expands grayscale/palette images).
    image = image.convert("RGB")

    # Feed the input to the device the model actually lives on rather than
    # assuming the module-level default; fall back to it for parameter-free
    # modules or plain callables.
    first_param = next(model.parameters(), None) if isinstance(model, nn.Module) else None
    target_device = first_param.device if first_param is not None else device

    # ToTensor converts the HWC uint8 image to a CHW float tensor in [0, 1];
    # unsqueeze adds the batch axis the model expects.
    img_tensor = transforms.ToTensor()(image).unsqueeze(0).to(target_device)
    with torch.no_grad():
        enhanced = model(img_tensor)

    # Convert back to PIL: clamp to [0, 1], scale, and round (instead of
    # truncating) to the nearest 8-bit level, then CHW -> HWC on the CPU.
    enhanced = enhanced.squeeze(0).clamp(0, 1).mul(255).round().to(torch.uint8)
    pixels = enhanced.permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(pixels)
# -----------------------
# Gradio Interface
# -----------------------
def process_image(input_img):
    """Gradio click handler: enhance ``input_img`` with the global ``zero_dce``.

    Returns ``None`` (leaving the output component empty) when no image has
    been uploaded, instead of crashing on ``None.convert``.
    """
    if input_img is None:
        return None
    return enhance_image(input_img, zero_dce)
with gr.Blocks() as demo:
    # Title (restored the intended crescent-moon emoji, previously mojibake).
    gr.Markdown("## 🌙 Low Light Image Enhancement (Zero-DCE)")
    with gr.Row():
        # type="pil" makes Gradio pass/accept PIL images, matching
        # process_image / enhance_image.
        inp = gr.Image(type="pil", label="Upload Dark Image")
        out = gr.Image(type="pil", label="Enhanced Image")
    run_btn = gr.Button("Enhance")
    # Event listeners must be registered inside the Blocks context.
    run_btn.click(fn=process_image, inputs=inp, outputs=out)

if __name__ == "__main__":
    # Bind on all interfaces so the app is reachable from outside a container.
    demo.launch(server_name="0.0.0.0", server_port=7860)