ayushshah commited on
Commit
7fc2cce
·
verified ·
1 Parent(s): e872817

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -2
app.py CHANGED
@@ -1,9 +1,104 @@
 
 
 
1
  import gradio as gr
 
 
 
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  def colorize(image):
4
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(colorize, gr.Image(), "image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  if __name__ == "__main__":
9
  demo.launch()
 
1
+ import torch
2
+ from safetensors.torch import load_file
3
+ from huggingface_hub import hf_hub_download
4
  import gradio as gr
5
+ from PIL import Image, ImageOps
6
+ import numpy as np
7
+ from kornia.color import rgb_to_lab, lab_to_rgb
8
 
9
+
10
+ REPO_ID = "ayushshah/imagecolorization"
11
+ WEIGHTS_FILE = "model.safetensors"
12
+ ARCHITECTURE_FILE = "model.py"
13
+
14
+
15
+ # Download architecture file
16
+ hf_hub_download(
17
+ repo_id=REPO_ID,
18
+ filename=ARCHITECTURE_FILE,
19
+ local_dir=".",
20
+ local_dir_use_symlinks=False
21
+ )
22
+
23
+ # Downloading the weights
24
+ weights_path = hf_hub_download(
25
+ repo_id=REPO_ID,
26
+ filename=WEIGHTS_FILE
27
+ )
28
+
29
+
30
+ # Initialize the model
31
+ from model import UNet
32
+
33
+ model = UNet()
34
+ state_dict = load_file(weights_path)
35
+ model.load_state_dict(state_dict)
36
+ model.eval()
37
+
38
+
39
+ # Center crop and resize to 224x224
40
+ def prepare_input(image):
41
+ if image is None:
42
+ raise gr.Error("Please upload an image.")
43
+ pil_image = Image.fromarray(image)
44
+ side = min(pil_image.size)
45
+ square = ImageOps.fit(
46
+ pil_image,
47
+ (side, side),
48
+ centering=(0.5, 0.5),
49
+ )
50
+ resized = square.resize((224, 224), Image.Resampling.BICUBIC)
51
+ return np.array(resized)
52
+
53
+
54
+ # Colorize the image
55
  def colorize(image):
56
+ image = image / 255.0
57
+ img_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float()
58
+
59
+ lab_tensor = rgb_to_lab(img_tensor)
60
+
61
+ L = lab_tensor[:, 0:1, :, :]
62
+ L_normalized = (L / 100.0)
63
+
64
+ with torch.no_grad():
65
+ ab_pred = model(L_normalized)
66
+
67
+ ab_pred = (ab_pred+1)*255.0/2-128.0
68
+ combined_lab = torch.cat([L, ab_pred], dim=1)
69
+ colorized_rgb = lab_to_rgb(combined_lab)
70
+ return colorized_rgb.squeeze().permute(1, 2, 0).numpy()
71
+
72
+
73
+ def clear_images():
74
+ return None, None
75
+
76
 
77
+ # Gradio interface
78
+ with gr.Blocks(title="Image Colorization") as demo:
79
+ gr.HTML("<h1 style='text-align: center;'>Image Colorization using UNet</h1>")
80
+ gr.Markdown(
81
+ "Upload a square image. If the image is not square, it will be center-cropped to a square image before resizing to 224x224."
82
+ )
83
+ gr.Markdown(
84
+ "The input image will also be converted to the LAB color space and the L channel will be given as input to the model."
85
+ )
86
+ with gr.Row():
87
+ with gr.Column():
88
+ input_image = gr.Image(
89
+ type="numpy",
90
+ label="Grayscale Input",
91
+ )
92
+ with gr.Row():
93
+ clear_btn = gr.Button("Clear")
94
+ submit_btn = gr.Button("Submit", variant="primary")
95
+ output_image = gr.Image(type="numpy", label="Colorized Output",image_mode='RGB')
96
+ input_image.upload(prepare_input, input_image, input_image)
97
+ submit_btn.click(colorize, input_image, output_image)
98
+ clear_btn.click(clear_images, None, [input_image, output_image])
99
+ gr.Markdown(
100
+ "This Huggingface space is running entirely on CPU. For faster performance, consider running it locally with a GPU or use Google Colab/Kaggle notebooks."
101
+ )
102
 
103
  if __name__ == "__main__":
104
  demo.launch()