JS6969 committed on
Commit
8c30d17
·
verified ·
1 Parent(s): eeebc1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -127
app.py CHANGED
@@ -9,72 +9,187 @@ from basicsr.utils.download_util import load_file_from_url
9
  from realesrgan import RealESRGANer
10
  from realesrgan.archs.srvgg_arch import SRVGGNetCompact
11
 
12
-
 
 
13
  last_file = None
14
  img_mode = "RGBA"
15
 
16
 
17
- def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
18
- """Real-ESRGAN function to restore (and upscale) images.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  """
20
- if not img:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  return
22
 
23
- # Define model parameters
24
  if model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
25
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
26
  netscale = 4
27
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
 
28
  elif model_name == 'RealESRNet_x4plus': # x4 RRDBNet model
29
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
30
  netscale = 4
31
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
 
32
  elif model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
33
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
34
  netscale = 4
35
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
 
36
  elif model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
37
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
38
  netscale = 2
39
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
 
40
  elif model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
41
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
42
  netscale = 4
 
43
  file_url = [
44
- 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
45
- 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
46
  ]
47
 
48
- # Determine model paths
49
- model_path = os.path.join('weights', model_name + '.pth')
50
- if not os.path.isfile(model_path):
51
- ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
52
- for url in file_url:
53
- # model_path will be updated
54
- model_path = load_file_from_url(
55
- url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
56
-
57
- # Use dni to control the denoise strength
58
- dni_weight = None
59
- if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
60
- wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
61
- model_path = [model_path, wdn_model_path]
62
- dni_weight = [denoise_strength, 1 - denoise_strength]
63
-
64
- # Restorer Class
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  upsampler = RealESRGANer(
66
  scale=netscale,
67
  model_path=model_path,
68
  dni_weight=dni_weight,
69
  model=model,
70
- tile=0,
71
  tile_pad=10,
72
  pre_pad=10,
73
- half=False,
74
- gpu_id=None
75
  )
76
 
77
- # Use GFPGAN for face enhancement
 
78
  if face_enhance:
79
  from gfpgan import GFPGANer
80
  face_enhancer = GFPGANer(
@@ -82,140 +197,105 @@ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
82
  upscale=outscale,
83
  arch='clean',
84
  channel_multiplier=2,
85
- bg_upsampler=upsampler)
 
86
 
87
- # Convert the input PIL image to cv2 image, so that it can be processed by realesrgan
88
  cv_img = numpy.array(img)
89
- img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
 
 
 
90
 
91
- # Apply restoration
92
  try:
93
- if face_enhance:
94
- _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
95
  else:
96
- output, _ = upsampler.enhance(img, outscale=outscale)
97
  except RuntimeError as error:
98
  print('Error', error)
99
- print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
 
 
 
 
 
 
100
  else:
101
- # Save restored image and return it to the output Image component
102
- if img_mode == 'RGBA': # RGBA images should be saved in png format
103
- extension = 'png'
104
- else:
105
- extension = 'jpg'
106
 
107
- out_filename = f"output_{rnd_string(8)}.{extension}"
 
108
  cv2.imwrite(out_filename, output)
109
  global last_file
110
  last_file = out_filename
111
- return out_filename
112
-
113
-
114
- def rnd_string(x):
115
- """Returns a string of 'x' random characters
116
- """
117
- characters = "abcdefghijklmnopqrstuvwxyz_0123456789"
118
- result = "".join((random.choice(characters)) for i in range(x))
119
- return result
120
-
121
-
122
- def reset():
123
- """Resets the Image components of the Gradio interface and deletes
124
- the last processed image
125
- """
126
- global last_file
127
- if last_file:
128
- print(f"Deleting {last_file} ...")
129
- os.remove(last_file)
130
- last_file = None
131
- return gr.update(value=None), gr.update(value=None)
132
-
133
-
134
- def has_transparency(img):
135
- """This function works by first checking to see if a "transparency" property is defined
136
- in the image's info -- if so, we return "True". Then, if the image is using indexed colors
137
- (such as in GIFs), it gets the index of the transparent color in the palette
138
- (img.info.get("transparency", -1)) and checks if it's used anywhere in the canvas
139
- (img.getcolors()). If the image is in RGBA mode, then presumably it has transparency in
140
- it, but it double-checks by getting the minimum and maximum values of every color channel
141
- (img.getextrema()), and checks if the alpha channel's smallest value falls below 255.
142
- https://stackoverflow.com/questions/43864101/python-pil-check-if-image-is-transparent
143
- """
144
- if img.info.get("transparency", None) is not None:
145
- return True
146
- if img.mode == "P":
147
- transparent = img.info.get("transparency", -1)
148
- for _, index in img.getcolors():
149
- if index == transparent:
150
- return True
151
- elif img.mode == "RGBA":
152
- extrema = img.getextrema()
153
- if extrema[3][0] < 255:
154
- return True
155
- return False
156
 
157
-
158
- def image_properties(img):
159
- """Returns the dimensions (width and height) and color mode of the input image and
160
- also sets the global img_mode variable to be used by the realesrgan function
161
- """
162
- global img_mode
163
- if img:
164
- if has_transparency(img):
165
- img_mode = "RGBA"
166
- else:
167
- img_mode = "RGB"
168
- properties = f"Resolution: Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img_mode}"
169
- return properties
170
 
171
 
 
 
 
172
  def main():
173
- # Gradio Interface
174
  with gr.Blocks(title="Real-ESRGAN Gradio Demo", theme="ParityError/Interstellar") as demo:
 
175
 
176
- gr.Markdown(
177
- """ Image Upscaler
178
- """
179
- )
180
-
181
- with gr.Accordion("Upscaling option"):
182
  with gr.Row():
183
- model_name = gr.Dropdown(label="Upscaler model",
184
- choices=["RealESRGAN_x4plus", "RealESRNet_x4plus", "RealESRGAN_x4plus_anime_6B",
185
- "RealESRGAN_x2plus", "realesr-general-x4v3"],
186
- value="RealESRGAN_x4plus_anime_6B", show_label=True)
187
- denoise_strength = gr.Slider(label="Denoise Strength",
188
- minimum=0, maximum=1, step=0.1, value=0.5)
189
- outscale = gr.Slider(label="Resolution upscale",
190
- minimum=1, maximum=6, step=1, value=4, show_label=True)
191
- face_enhance = gr.Checkbox(label="Face Enhancement (GFPGAN)",
 
 
 
 
 
 
192
  )
193
-
 
 
 
 
 
 
 
 
194
  with gr.Row():
195
  with gr.Group():
196
  input_image = gr.Image(label="Input Image", type="pil", image_mode="RGBA")
197
  input_image_properties = gr.Textbox(label="Image Properties", max_lines=1)
198
  output_image = gr.Image(label="Output Image", image_mode="RGBA")
 
199
  with gr.Row():
200
  reset_btn = gr.Button("Remove images")
201
  restore_btn = gr.Button("Upscale")
202
 
203
  # Event listeners:
204
  input_image.change(fn=image_properties, inputs=input_image, outputs=input_image_properties)
205
- restore_btn.click(fn=realesrgan,
206
- inputs=[input_image, model_name, denoise_strength, face_enhance, outscale],
207
- outputs=output_image)
208
- reset_btn.click(fn=reset, inputs=[], outputs=[output_image, input_image])
209
- # reset_btn.click(None, inputs=[], outputs=[input_image], _js="() => (null)\n")
210
- # Undocumented method to clear a component's value using Javascript
211
 
212
- gr.Markdown(
213
- """
214
- """
 
215
  )
 
 
 
216
 
217
  demo.launch()
218
 
219
 
220
  if __name__ == "__main__":
221
- main()
 
9
  from realesrgan import RealESRGANer
10
  from realesrgan.archs.srvgg_arch import SRVGGNetCompact
11
 
12
+ # ────────────────────────────────────────────────────────
13
+ # Globals
14
+ # ────────────────────────────────────────────────────────
15
  last_file = None
16
  img_mode = "RGBA"
17
 
18
 
19
+ # ────────────────────────────────────────────────────────
20
+ # Utilities
21
+ # ────────────────────────────────────────────────────────
22
+ def rnd_string(x: int) -> str:
23
+ """Returns a string of 'x' random characters."""
24
+ characters = "abcdefghijklmnopqrstuvwxyz_0123456789"
25
+ result = "".join((random.choice(characters)) for _ in range(x))
26
+ return result
27
+
28
+
29
+ def reset():
30
+ """Resets the Image components and deletes the last processed image."""
31
+ global last_file
32
+ if last_file:
33
+ try:
34
+ print(f"Deleting {last_file} ...")
35
+ os.remove(last_file)
36
+ except Exception as e:
37
+ print("Delete error:", e)
38
+ last_file = None
39
+ return gr.update(value=None), gr.update(value=None)
40
+
41
+
42
+ def has_transparency(img):
43
  """
44
+ Check for transparency in a PIL image.
45
+ https://stackoverflow.com/questions/43864101/python-pil-check-if-image-is-transparent
46
+ """
47
+ if img.info.get("transparency", None) is not None:
48
+ return True
49
+ if img.mode == "P":
50
+ transparent = img.info.get("transparency", -1)
51
+ for _, index in img.getcolors():
52
+ if index == transparent:
53
+ return True
54
+ elif img.mode == "RGBA":
55
+ extrema = img.getextrema()
56
+ if extrema[3][0] < 255:
57
+ return True
58
+ return False
59
+
60
+
61
+ def image_properties(img):
62
+ """Return resolution & color mode of the input image; set global img_mode."""
63
+ global img_mode
64
+ if img:
65
+ if has_transparency(img):
66
+ img_mode = "RGBA"
67
+ else:
68
+ img_mode = "RGB"
69
+ properties = f"Resolution: Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img_mode}"
70
+ return properties
71
+
72
+
73
+ def model_tip_text(model_name: str) -> str:
74
+ """Return human-friendly guidance for the chosen model."""
75
+ tips = {
76
+ "RealESRGAN_x4plus": (
77
+ "**RealESRGAN_x4plus (4Γ—)** β€” Best for photoreal images (portraits, landscapes). "
78
+ "Balanced detail recovery. Good default for Flux realism."
79
+ ),
80
+ "RealESRNet_x4plus": (
81
+ "**RealESRNet_x4plus (4×)** — Softer but great on noisy/compressed sources "
82
+ "(old JPEGs, screenshots)."
83
+ ),
84
+ "RealESRGAN_x4plus_anime_6B": (
85
+ "**RealESRGAN_x4plus_anime_6B (4×)** — For anime/illustrations/line art only. "
86
+ "Not recommended for real-life photos."
87
+ ),
88
+ "RealESRGAN_x2plus": (
89
+ "**RealESRGAN_x2plus (2×)** — Faster, lighter 2× cleanup when you don't need 4×."
90
+ ),
91
+ "realesr-general-x4v3": (
92
+ "**realesr-general-x4v3 (4×)** — Versatile mixed-content model with adjustable denoise. "
93
+ "**Denoise Strength** slider only affects this model (blends with the WDN variant). "
94
+ "Try 0.3–0.5 for slightly cleaner, sharper results."
95
+ ),
96
+ }
97
+ return tips.get(model_name, "")
98
+
99
+
100
+ # ────────────────────────────────────────────────────────
101
+ # Core upscaling
102
+ # ────────────────────────────────────────────────────────
103
+ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
104
+ """Real-ESRGAN function to restore (and upscale) images with robust defaults."""
105
+ if img is None:
106
  return
107
 
108
+ # ----- Select backbone + weights -----
109
  if model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
110
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
111
  netscale = 4
112
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
113
+
114
  elif model_name == 'RealESRNet_x4plus': # x4 RRDBNet model
115
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
116
  netscale = 4
117
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
118
+
119
  elif model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
120
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
121
  netscale = 4
122
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
123
+
124
  elif model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
125
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
126
  netscale = 2
127
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
128
+
129
  elif model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
130
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
131
  netscale = 4
132
+ # We'll ensure BOTH base and WDN weights exist; order matters for DNI.
133
  file_url = [
134
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
135
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth'
136
  ]
137
 
138
+ else:
139
+ raise ValueError(f"Unknown model: {model_name}")
140
+
141
+ # ----- Ensure weights are on disk -----
142
+ # For the general-x4v3 case we download both; for others single file is fine.
143
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
144
+ weights_dir = os.path.join(ROOT_DIR, 'weights')
145
+ os.makedirs(weights_dir, exist_ok=True)
146
+
147
+ # Track model paths
148
+ local_paths = []
149
+ for url in file_url:
150
+ fname = os.path.basename(url)
151
+ local_path = os.path.join(weights_dir, fname)
152
+ if not os.path.isfile(local_path):
153
+ local_path = load_file_from_url(url=url, model_dir=weights_dir, progress=True)
154
+ local_paths.append(local_path)
155
+
156
+ # Default path(s)
157
+ if model_name == 'realesr-general-x4v3':
158
+ # Order: [base, wdn] then set DNI weights accordingly
159
+ base_path = os.path.join(weights_dir, 'realesr-general-x4v3.pth')
160
+ wdn_path = os.path.join(weights_dir, 'realesr-general-wdn-x4v3.pth')
161
+ model_path = [base_path, wdn_path]
162
+ denoise_strength = float(denoise_strength)
163
+ # Weight for WDN equals denoise_strength (cleaner); base gets the remainder
164
+ dni_weight = [1.0 - denoise_strength, denoise_strength]
165
+ else:
166
+ model_path = os.path.join(weights_dir, f"{model_name}.pth")
167
+ dni_weight = None
168
+
169
+ # ----- CUDA / precision / tiling -----
170
+ # Be defensive: cv2.cuda may not exist in CPU-only builds.
171
+ use_cuda = False
172
+ try:
173
+ use_cuda = hasattr(cv2, "cuda") and cv2.cuda.getCudaEnabledDeviceCount() > 0
174
+ except Exception:
175
+ use_cuda = False
176
+
177
+ gpu_id = 0 if use_cuda else None
178
+
179
  upsampler = RealESRGANer(
180
  scale=netscale,
181
  model_path=model_path,
182
  dni_weight=dni_weight,
183
  model=model,
184
+ tile=256, # Safe VRAM default; increase if you have headroom
185
  tile_pad=10,
186
  pre_pad=10,
187
+ half=bool(use_cuda), # FP16 on GPU
188
+ gpu_id=gpu_id
189
  )
190
 
191
+ # ----- Optional face enhancement -----
192
+ face_enhancer = None
193
  if face_enhance:
194
  from gfpgan import GFPGANer
195
  face_enhancer = GFPGANer(
 
197
  upscale=outscale,
198
  arch='clean',
199
  channel_multiplier=2,
200
+ bg_upsampler=upsampler
201
+ )
202
 
203
+ # ----- Convert PIL -> cv2 (handle RGB/RGBA) -----
204
  cv_img = numpy.array(img)
205
+ if cv_img.ndim == 3 and cv_img.shape[2] == 4:
206
+ cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
207
+ else:
208
+ cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGB2BGR)
209
 
210
+ # ----- Enhance -----
211
  try:
212
+ if face_enhancer:
213
+ _, _, output = face_enhancer.enhance(cv_img, has_aligned=False, only_center_face=False, paste_back=True)
214
  else:
215
+ output, _ = upsampler.enhance(cv_img, outscale=int(outscale))
216
  except RuntimeError as error:
217
  print('Error', error)
218
+ print('Tip: If you hit CUDA OOM, try a smaller tile size (e.g., 128).')
219
+ return None
220
+
221
+ # ----- cv2 -> RGBA/RGB for Gradio, also save -----
222
+ if output.ndim == 3 and output.shape[2] == 4:
223
+ display_img = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
224
+ extension = 'png'
225
  else:
226
+ display_img = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
227
+ extension = 'jpg'
 
 
 
228
 
229
+ out_filename = f"output_{rnd_string(8)}.{extension}"
230
+ try:
231
  cv2.imwrite(out_filename, output)
232
  global last_file
233
  last_file = out_filename
234
+ except Exception as e:
235
+ print("Save error:", e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
+ return display_img # ndarray so Gradio displays immediately
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
 
240
+ # ────────────────────────────────────────────────────────
241
+ # UI
242
+ # ────────────────────────────────────────────────────────
243
  def main():
 
244
  with gr.Blocks(title="Real-ESRGAN Gradio Demo", theme="ParityError/Interstellar") as demo:
245
+ gr.Markdown("## Image Upscaler")
246
 
247
+ with gr.Accordion("Upscaling options", open=True):
 
 
 
 
 
248
  with gr.Row():
249
+ model_name = gr.Dropdown(
250
+ label="Upscaler model",
251
+ choices=[
252
+ "RealESRGAN_x4plus",
253
+ "RealESRNet_x4plus",
254
+ "RealESRGAN_x4plus_anime_6B",
255
+ "RealESRGAN_x2plus",
256
+ "realesr-general-x4v3",
257
+ ],
258
+ value="RealESRGAN_x4plus", # photoreal default
259
+ show_label=True
260
+ )
261
+ denoise_strength = gr.Slider(
262
+ label="Denoise Strength (only for realesr-general-x4v3)",
263
+ minimum=0, maximum=1, step=0.1, value=0.5
264
  )
265
+ outscale = gr.Slider(
266
+ label="Resolution upscale",
267
+ minimum=1, maximum=6, step=1, value=4, show_label=True
268
+ )
269
+ face_enhance = gr.Checkbox(label="Face Enhancement (GFPGAN)", value=False)
270
+
271
+ # Model tips panel (auto-updates)
272
+ model_tips = gr.Markdown(model_tip_text("RealESRGAN_x4plus"))
273
+
274
  with gr.Row():
275
  with gr.Group():
276
  input_image = gr.Image(label="Input Image", type="pil", image_mode="RGBA")
277
  input_image_properties = gr.Textbox(label="Image Properties", max_lines=1)
278
  output_image = gr.Image(label="Output Image", image_mode="RGBA")
279
+
280
  with gr.Row():
281
  reset_btn = gr.Button("Remove images")
282
  restore_btn = gr.Button("Upscale")
283
 
284
  # Event listeners:
285
  input_image.change(fn=image_properties, inputs=input_image, outputs=input_image_properties)
286
+ model_name.change(fn=model_tip_text, inputs=model_name, outputs=model_tips)
 
 
 
 
 
287
 
288
+ restore_btn.click(
289
+ fn=realesrgan,
290
+ inputs=[input_image, model_name, denoise_strength, face_enhance, outscale],
291
+ outputs=output_image
292
  )
293
+ reset_btn.click(fn=reset, inputs=[], outputs=[output_image, input_image])
294
+
295
+ gr.Markdown("") # spacer
296
 
297
  demo.launch()
298
 
299
 
300
  if __name__ == "__main__":
301
+ main()