Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import numpy as np | |
| import gradio as gr | |
# -----------------------
# Zero-DCE Model
# -----------------------
class ZeroDCE(nn.Module):
    """Small fully-convolutional network that adds a bounded residual to an RGB image.

    Layer stack (all convs are 3x3, stride 1, padding 1, so H and W are kept):
    one 3->32 conv + ReLU, ``num_layers - 2`` pairs of 32->32 conv + ReLU, and
    a final 32->3 conv + Tanh. ``forward`` returns the input plus the Tanh
    output, which lies in [-1, 1].

    NOTE(review): despite the name, this is not the curve-estimation
    formulation from the Zero-DCE paper (which predicts per-pixel curve
    parameters and applies them iteratively); it is a plain residual CNN.

    Args:
        num_layers: Total number of conv layers. Values below 2 still yield
            the two mandatory (first and last) conv layers.
    """

    def __init__(self, num_layers=7):
        super().__init__()
        width = 32
        # Build head, body, tail in that order so parameter initialisation
        # (and therefore RNG consumption) follows the layer order.
        head = [nn.Conv2d(3, width, 3, 1, 1), nn.ReLU(inplace=True)]
        body = [
            layer
            for _ in range(num_layers - 2)
            for layer in (nn.Conv2d(width, width, 3, 1, 1), nn.ReLU(inplace=True))
        ]
        tail = [nn.Conv2d(width, 3, 3, 1, 1), nn.Tanh()]
        # A single Sequential named ``net`` keeps state_dict keys as net.<i>.*
        self.net = nn.Sequential(*head, *body, *tail)

    def forward(self, x):
        """Return ``x`` plus the network's Tanh-bounded residual (same shape as a 3-channel ``x``)."""
        residual = self.net(x)
        return x + residual
# -----------------------
# Setup
# -----------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): no pretrained weights are loaded in the code here, so the model
# runs with its random initialisation and its output is not a meaningful
# enhancement. TODO: load a trained checkpoint, e.g.
#   zero_dce.load_state_dict(torch.load(path, map_location=device, weights_only=True))
# The model is only used for inference, so switch it to eval mode once; this
# keeps results correct if train/eval-sensitive layers (dropout, batch norm)
# are ever added to the architecture.
zero_dce = ZeroDCE().to(device).eval()
# -----------------------
# Image Enhancement
# -----------------------
def enhance_image(image, model):
    """Run ``model`` on a PIL image and return the result as a new RGB PIL image.

    The image is converted to RGB and turned into a float tensor in [0, 1] of
    shape (1, 3, H, W) with ``transforms.ToTensor``. It is passed through
    ``model`` without gradient tracking, clipped back to [0, 1], and quantised
    to 8-bit with rounding.

    Args:
        image: A PIL image in any mode; it is always converted to RGB first.
        model: Callable mapping a (1, 3, H, W) float tensor to a tensor of the
            same shape. If it is an ``nn.Module`` with parameters, the input is
            placed on that module's device. Otherwise the module-level
            ``device`` is used.

    Returns:
        A ``PIL.Image.Image`` in RGB mode.
    """
    # Always force RGB so grayscale / RGBA / palette uploads give 3 channels.
    image = image.convert("RGB")

    # Put the input where the model's weights actually live instead of
    # assuming the module-level default (a CPU model on a CUDA machine would
    # otherwise fail with a device mismatch).
    first_param = next(model.parameters(), None) if isinstance(model, nn.Module) else None
    target_device = first_param.device if first_param is not None else device

    img_tensor = transforms.ToTensor()(image).unsqueeze(0).to(target_device)
    with torch.no_grad():
        enhanced = model(img_tensor)

    # Convert back to PIL. Clip in float space, then round (not truncate) when
    # quantising: e.g. 0.999 * 255 = 254.7 must become 255, not 254, and an
    # identity model should round-trip every pixel exactly.
    enhanced = enhanced.squeeze(0).clamp(0.0, 1.0).cpu().permute(1, 2, 0).numpy()
    enhanced = np.rint(enhanced * 255.0).astype(np.uint8)
    return Image.fromarray(enhanced)
# -----------------------
# Gradio Interface
# -----------------------
def process_image(input_img):
    """Gradio callback: enhance the uploaded image with the module-level ``zero_dce`` model.

    Args:
        input_img: PIL image from the input component, or ``None`` when the
            user clicks "Enhance" without uploading anything.

    Returns:
        The enhanced PIL image.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of an ``AttributeError`` from calling ``.convert`` on None.
    """
    if input_img is None:
        raise gr.Error("Please upload an image first.")
    return enhance_image(input_img, zero_dce)
# Layout: title, a row with the input and output images, then the button.
# The click listener is registered inside the Blocks context, as Gradio requires.
with gr.Blocks() as demo:
    gr.Markdown("## Low Light Image Enhancement (Zero-DCE)")
    with gr.Row():
        inp = gr.Image(type="pil", label="Upload Dark Image")
        out = gr.Image(type="pil", label="Enhanced Image")
    run_btn = gr.Button("Enhance")
    run_btn.click(fn=process_image, inputs=inp, outputs=out)
if __name__ == "__main__":
    # Listen on all network interfaces (not only localhost) on port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)