MogensR commited on
Commit
b2228a7
·
1 Parent(s): f1216a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +241 -70
app.py CHANGED
@@ -6,10 +6,13 @@
6
 
7
  import early_env # <<< must be FIRST
8
 
9
- import os, time, tempfile
10
- from pathlib import Path
11
  from typing import Optional, Dict, Any, Callable, Tuple
12
 
 
 
 
 
13
  # 1) CSP-safe Gradio env
14
  os.environ['GRADIO_ALLOW_FLAGGING'] = 'never'
15
  os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
@@ -19,15 +22,15 @@
19
  # 2) Gradio schema patch
20
  try:
21
  import gradio_client.utils as gc_utils
22
- orig_get_type = gc_utils.get_type
23
- def patched_get_type(schema):
24
  if not isinstance(schema, dict):
25
  if isinstance(schema, bool): return "boolean"
26
  if isinstance(schema, str): return "string"
27
  if isinstance(schema, (int, float)): return "number"
28
  return "string"
29
- return orig_get_type(schema)
30
- gc_utils.get_type = patched_get_type
31
  except Exception:
32
  pass
33
 
@@ -48,7 +51,7 @@ def patched_get_type(schema):
48
 
49
  # Background helpers
50
  from utils import PROFESSIONAL_BACKGROUNDS, validate_video_file, create_professional_background
51
- # Gradient helper (add this to utils; fallback here for preview only if missing)
52
  try:
53
  from utils import create_gradient_background
54
  except Exception:
@@ -61,20 +64,18 @@ def _to_rgb(c):
61
  return tuple(int(x) for x in c)
62
  if isinstance(c, str) and c.startswith("#") and len(c) == 7:
63
  return tuple(int(c[i:i+2], 16) for i in (1,3,5))
64
- return (255,255,255)
65
  start = _to_rgb(spec.get("start", "#222222"))
66
  end = _to_rgb(spec.get("end", "#888888"))
67
  angle = float(spec.get("angle_deg", 0))
68
- # build vertical then rotate
69
  bg = np.zeros((height, width, 3), np.uint8)
70
  for y in range(height):
71
- t = y / max(1, height-1)
72
- r = int(start[0]*(1-t) + end[0]*t)
73
- g = int(start[1]*(1-t) + end[1]*t)
74
- b = int(start[2]*(1-t) + end[2]*t)
75
- bg[y,:] = (r,g,b)
76
- # rotate to angle
77
- center = (width/2, height/2)
78
  rot = cv2.getRotationMatrix2D(center, angle, 1.0)
79
  return cv2.warpAffine(bg, rot, (width, height), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101)
80
 
@@ -110,20 +111,25 @@ def process(self, image, mask, **kwargs):
110
  import numpy as np
111
  import cv2
112
  from PIL import Image
 
113
 
114
  PREVIEW_W, PREVIEW_H = 640, 360 # 16:9
115
 
116
- def _hex_to_rgb(x: str) -> Tuple[int,int,int]:
117
- x = x.strip()
118
  if x.startswith("#") and len(x) == 7:
119
- return tuple(int(x[i:i+2], 16) for i in (1,3,5))
120
- return (255,255,255)
121
 
122
  def _np_to_pil(arr: np.ndarray) -> Image.Image:
123
  if arr.dtype != np.uint8:
124
- arr = arr.clip(0,255).astype(np.uint8)
125
  return Image.fromarray(arr)
126
 
 
 
 
 
127
  # ---------- main app ----------
128
  class VideoBackgroundApp:
129
  def __init__(self):
@@ -134,9 +140,12 @@ def __init__(self):
134
  self.audio_proc = AudioProcessor()
135
  self.models_loaded = False
136
  self.core_processor: Optional[CoreVideoProcessor] = None
 
 
 
137
  logger.info("VideoBackgroundApp initialized (device=%s)", self.device_mgr.get_optimal_device())
138
 
139
- def load_models(self, progress_callback: Optional[Callable]=None) -> str:
140
  logger.info("Loading models (CSP-safe)…")
141
  try:
142
  sam2, matanyone = self.model_loader.load_all_models(progress_callback=progress_callback)
@@ -176,7 +185,8 @@ def preview_preset(self, preset_key: str) -> Image.Image:
176
  return _np_to_pil(bg)
177
 
178
  def preview_upload(self, file) -> Optional[Image.Image]:
179
- if file is None: return None
 
180
  try:
181
  img = Image.open(file.name).convert("RGB")
182
  img = img.resize((PREVIEW_W, PREVIEW_H), Image.LANCZOS)
@@ -187,35 +197,173 @@ def preview_upload(self, file) -> Optional[Image.Image]:
187
 
188
  def preview_gradient(self, gtype: str, color1: str, color2: str, angle: int) -> Image.Image:
189
  spec = {
190
- "type": gtype.lower(), # "linear" or "radial" (linear in fallback)
191
- "start": _hex_to_rgb(color1),
192
- "end": _hex_to_rgb(color2),
193
- "angle_deg": float(angle),
194
  }
195
  bg = create_gradient_background(spec, PREVIEW_W, PREVIEW_H)
196
  return _np_to_pil(bg)
197
 
198
- def ai_generate_background(self, prompt: str, seed: int, width: int, height: int) -> Tuple[Optional[Image.Image], Optional[str], str]:
 
199
  """
200
- Try generating a background with diffusers; save to /tmp and return (img, path, status).
 
201
  """
 
 
 
202
  try:
203
- from diffusers import StableDiffusionPipeline
204
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  model_id = os.environ.get("BGFX_T2I_MODEL", "stabilityai/stable-diffusion-2-1")
206
- dtype = torch.float16 if torch.cuda.is_available() else torch.float32
207
- device = "cuda" if torch.cuda.is_available() else "cpu"
208
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  pipe = pipe.to(device)
210
- g = torch.Generator(device=device).manual_seed(int(seed)) if seed is not None else None
211
- with torch.autocast(device if device=="cuda" else "cpu"):
212
- img = pipe(prompt, height=height, width=width, guidance_scale=7.0, num_inference_steps=25, generator=g).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  tmp_path = f"/tmp/ai_bg_{int(time.time())}.png"
214
  img.save(tmp_path)
215
- return img.resize((PREVIEW_W, PREVIEW_H), Image.LANCZOS), tmp_path, f"AI background generated ✓ ({os.path.basename(tmp_path)})"
 
216
  except Exception as e:
217
- logger.warning("AI generation unavailable: %s", e)
218
- return None, None, f"AI generation unavailable: {e}"
219
 
220
  # ---- PROCESS VIDEO ----
221
  def process_video(
@@ -233,6 +381,9 @@ def process_video(
233
  if not self.models_loaded:
234
  return None, "Models not loaded yet"
235
 
 
 
 
236
  logger.info("process_video called (video=%s, source=%s, preset=%s, file=%s, grad=%s, ai=%s)",
237
  video, bg_source, preset_key, getattr(custom_bg_file, "name", None) if custom_bg_file else None,
238
  {"type": grad_type, "c1": grad_color1, "c2": grad_color2, "angle": grad_angle},
@@ -247,10 +398,9 @@ def process_video(
247
  return None, "Invalid or unreadable video file"
248
 
249
  # Build bg_config based on source
250
- bg_cfg: Dict[str, Any]
251
  src = (bg_source or "Preset").lower()
252
  if src == "upload" and custom_bg_file is not None:
253
- bg_cfg = {"custom_path": custom_bg_file.name}
254
  elif src == "gradient":
255
  bg_cfg = {
256
  "gradient": {
@@ -311,9 +461,12 @@ def create_csp_safe_gradio():
311
 
312
  # PRESET
313
  preset_choices = list(PROFESSIONAL_BACKGROUNDS.keys())
314
- preset_key = gr.Dropdown(choices=preset_choices, value=("office" if "office" in preset_choices else preset_choices[0]), label="Preset")
 
 
315
  # UPLOAD
316
  custom_bg = gr.File(label="Custom Background (Image)", file_types=["image"], visible=False)
 
317
  # GRADIENT
318
  grad_type = gr.Dropdown(choices=["Linear", "Radial"], value="Linear", label="Gradient Type", visible=False)
319
  grad_color1 = gr.ColorPicker(value="#222222", label="Start Color", visible=False)
@@ -339,48 +492,67 @@ def create_csp_safe_gradio():
339
  # ---------- UI wiring ----------
340
 
341
  # background source → show/hide controls
342
- def on_source_change(src):
343
  src = (src or "Preset").lower()
344
  return (
345
- gr.update(visible=(src=="preset")),
346
- gr.update(visible=(src=="upload")),
347
- gr.update(visible=(src=="gradient")),
348
- gr.update(visible=(src=="gradient")),
349
- gr.update(visible=(src=="gradient")),
350
- gr.update(visible=(src=="gradient")),
351
- gr.update(visible=(src=="ai generate")),
352
- gr.update(visible=(src=="ai generate")),
353
- gr.update(visible=(src=="ai generate")),
354
- gr.update(visible=(src=="ai generate")),
355
- gr.update(visible=(src=="ai generate")),
356
  )
357
  bg_source.change(
358
- fn=on_source_change,
359
  inputs=[bg_source],
360
  outputs=[preset_key, custom_bg, grad_type, grad_color1, grad_color2, grad_angle, ai_prompt, ai_seed, ai_size, ai_go, ai_status],
361
  )
362
 
363
- # live previews
364
- def preview_from_preset(key):
365
- return app.preview_preset(key)
366
- preset_key.change(fn=preview_from_preset, inputs=[preset_key], outputs=[bg_preview])
367
-
368
- def preview_from_upload(file):
369
- return app.preview_upload(file)
370
- custom_bg.change(fn=preview_from_upload, inputs=[custom_bg], outputs=[bg_preview])
 
 
 
 
 
 
 
 
 
 
 
371
 
372
- def preview_from_gradient(gt, c1, c2, ang):
373
- return app.preview_gradient(gt, c1, c2, ang)
 
374
  for comp in (grad_type, grad_color1, grad_color2, grad_angle):
375
- comp.change(fn=preview_from_gradient, inputs=[grad_type, grad_color1, grad_color2, grad_angle], outputs=[bg_preview])
 
 
 
 
376
 
377
  # AI generate
378
  def ai_generate(prompt, seed, size):
379
  try:
380
- w,h = map(int, size.split("x"))
381
  except Exception:
382
- w,h = PREVIEW_W, PREVIEW_H
383
- img, path, msg = app.ai_generate_background(prompt or "professional modern office background, neutral colors, depth of field", int(seed), w, h)
 
 
 
384
  return img, (path or None), msg
385
  ai_go.click(fn=ai_generate, inputs=[ai_prompt, ai_seed, ai_size], outputs=[bg_preview, ai_bg_path_state, ai_status])
386
 
@@ -388,7 +560,6 @@ def ai_generate(prompt, seed, size):
388
  def safe_load():
389
  msg = app.load_models()
390
  logger.info("UI: models loaded")
391
- # set initial preview (preset default)
392
  return msg, app.preview_preset(preset_key.value if hasattr(preset_key, "value") else "office")
393
  btn_load.click(fn=safe_load, outputs=[status, bg_preview])
394
 
 
6
 
7
  import early_env # <<< must be FIRST
8
 
9
+ import os, time, math
 
10
  from typing import Optional, Dict, Any, Callable, Tuple
11
 
12
+ # Prefer a writable cache on HF/Spaces
13
+ os.environ.setdefault("HF_HOME", "/tmp/hf")
14
+ os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
15
+
16
  # 1) CSP-safe Gradio env
17
  os.environ['GRADIO_ALLOW_FLAGGING'] = 'never'
18
  os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
 
22
  # 2) Gradio schema patch
23
  try:
24
  import gradio_client.utils as gc_utils
25
+ _orig_get_type = gc_utils.get_type
26
+ def _patched_get_type(schema):
27
  if not isinstance(schema, dict):
28
  if isinstance(schema, bool): return "boolean"
29
  if isinstance(schema, str): return "string"
30
  if isinstance(schema, (int, float)): return "number"
31
  return "string"
32
+ return _orig_get_type(schema)
33
+ gc_utils.get_type = _patched_get_type
34
  except Exception:
35
  pass
36
 
 
51
 
52
  # Background helpers
53
  from utils import PROFESSIONAL_BACKGROUNDS, validate_video_file, create_professional_background
54
+ # Gradient helper (add to utils; fallback here for preview only if missing)
55
  try:
56
  from utils import create_gradient_background
57
  except Exception:
 
64
  return tuple(int(x) for x in c)
65
  if isinstance(c, str) and c.startswith("#") and len(c) == 7:
66
  return tuple(int(c[i:i+2], 16) for i in (1,3,5))
67
+ return (255, 255, 255)
68
  start = _to_rgb(spec.get("start", "#222222"))
69
  end = _to_rgb(spec.get("end", "#888888"))
70
  angle = float(spec.get("angle_deg", 0))
 
71
  bg = np.zeros((height, width, 3), np.uint8)
72
  for y in range(height):
73
+ t = y / max(1, height - 1)
74
+ r = int(start[0] * (1 - t) + end[0] * t)
75
+ g = int(start[1] * (1 - t) + end[1] * t)
76
+ b = int(start[2] * (1 - t) + end[2] * t)
77
+ bg[y, :] = (r, g, b)
78
+ center = (width / 2, height / 2)
 
79
  rot = cv2.getRotationMatrix2D(center, angle, 1.0)
80
  return cv2.warpAffine(bg, rot, (width, height), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101)
81
 
 
111
  import numpy as np
112
  import cv2
113
  from PIL import Image
114
+ from typing import Tuple
115
 
116
  PREVIEW_W, PREVIEW_H = 640, 360 # 16:9
117
 
118
+ def _hex_to_rgb(x: str) -> Tuple[int, int, int]:
119
+ x = (x or "").strip()
120
  if x.startswith("#") and len(x) == 7:
121
+ return tuple(int(x[i:i+2], 16) for i in (1, 3, 5))
122
+ return (255, 255, 255)
123
 
124
  def _np_to_pil(arr: np.ndarray) -> Image.Image:
125
  if arr.dtype != np.uint8:
126
+ arr = arr.clip(0, 255).astype(np.uint8)
127
  return Image.fromarray(arr)
128
 
129
+ def _div8(n: int) -> int:
130
+ # Ensure sizes are multiples of 8 for SD/VAEs
131
+ return int(math.floor(max(64, n) / 8.0) * 8)
132
+
133
  # ---------- main app ----------
134
  class VideoBackgroundApp:
135
  def __init__(self):
 
140
  self.audio_proc = AudioProcessor()
141
  self.models_loaded = False
142
  self.core_processor: Optional[CoreVideoProcessor] = None
143
+ # Text-to-Image pipeline cache
144
+ self.t2i_pipe = None
145
+ self.t2i_model_id = None
146
  logger.info("VideoBackgroundApp initialized (device=%s)", self.device_mgr.get_optimal_device())
147
 
148
+ def load_models(self, progress_callback: Optional[Callable] = None) -> str:
149
  logger.info("Loading models (CSP-safe)…")
150
  try:
151
  sam2, matanyone = self.model_loader.load_all_models(progress_callback=progress_callback)
 
185
  return _np_to_pil(bg)
186
 
187
  def preview_upload(self, file) -> Optional[Image.Image]:
188
+ if file is None:
189
+ return None
190
  try:
191
  img = Image.open(file.name).convert("RGB")
192
  img = img.resize((PREVIEW_W, PREVIEW_H), Image.LANCZOS)
 
197
 
198
  def preview_gradient(self, gtype: str, color1: str, color2: str, angle: int) -> Image.Image:
199
  spec = {
200
+ "type": (gtype or "linear").lower(), # "linear" or "radial" (linear in fallback)
201
+ "start": _hex_to_rgb(color1 or "#222222"),
202
+ "end": _hex_to_rgb(color2 or "#888888"),
203
+ "angle_deg": float(angle or 0),
204
  }
205
  bg = create_gradient_background(spec, PREVIEW_W, PREVIEW_H)
206
  return _np_to_pil(bg)
207
 
208
+ # ---- AI BG: lazy-load + reuse pipe ----
209
+ def _ensure_t2i(self):
210
  """
211
+ Choose and load a text-to-image pipeline once, with memory-efficient settings.
212
+ Returns (pipe, model_id, msg)
213
  """
214
+ if self.t2i_pipe is not None:
215
+ return self.t2i_pipe, self.t2i_model_id, "AI generator ready"
216
+
217
  try:
 
218
  import torch
219
+ from diffusers import StableDiffusionPipeline, AutoPipelineForText2Image
220
+ except Exception as e:
221
+ return None, None, f"AI generation unavailable (missing deps): {e}"
222
+
223
+ # Heuristic: prefer fast/light models when VRAM is small
224
+ token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
225
+ device = "cuda" if getattr(torch, "cuda", None) and torch.cuda.is_available() else "cpu"
226
+
227
+ vram_gb = None
228
+ try:
229
+ vram_gb = self.device_mgr.get_device_memory_gb()
230
+ except Exception:
231
+ pass
232
+
233
+ # Prefer SD-Turbo if GPU and small VRAM; SDXL-Turbo if large VRAM; fallback to SD 2.1 on CPU
234
+ if device == "cuda":
235
+ if vram_gb and vram_gb >= 12:
236
+ model_id = os.environ.get("BGFX_T2I_MODEL", "stabilityai/sdxl-turbo")
237
+ else:
238
+ model_id = os.environ.get("BGFX_T2I_MODEL", "stabilityai/sd-turbo")
239
+ else:
240
+ # CPU-friendly (still heavy): classic SD 2.1
241
  model_id = os.environ.get("BGFX_T2I_MODEL", "stabilityai/stable-diffusion-2-1")
242
+
243
+ logger.info(f"Loading text-to-image model: {model_id} (device={device}, vram={vram_gb} GB)")
244
+
245
+ dtype = torch.float16 if device == "cuda" else torch.float32
246
+
247
+ pipe = None
248
+ err = None
249
+ try:
250
+ # Newer unified API handles sd-turbo and sdxl-turbo too
251
+ pipe = AutoPipelineForText2Image.from_pretrained(
252
+ model_id,
253
+ torch_dtype=dtype,
254
+ use_safetensors=True,
255
+ token=token
256
+ )
257
+ except Exception as e1:
258
+ err = e1
259
+ try:
260
+ # Fallback to classic pipeline (works for sd/stable-diffusion-2-1)
261
+ pipe = StableDiffusionPipeline.from_pretrained(
262
+ model_id,
263
+ torch_dtype=dtype,
264
+ use_safetensors=True,
265
+ safety_checker=None, # disable to avoid false positives for office backgrounds
266
+ feature_extractor=None,
267
+ use_auth_token=token # legacy name
268
+ )
269
+ except Exception as e2:
270
+ return None, None, f"AI model load failed: {e1} / {e2}"
271
+
272
+ # Memory/perf knobs
273
+ try:
274
+ pipe.set_progress_bar_config(disable=True)
275
+ except Exception:
276
+ pass
277
+ try:
278
+ pipe.enable_attention_slicing()
279
+ except Exception:
280
+ pass
281
+ try:
282
+ pipe.enable_vae_slicing()
283
+ except Exception:
284
+ pass
285
+ if device == "cuda":
286
+ try:
287
+ pipe.enable_xformers_memory_efficient_attention()
288
+ except Exception:
289
+ pass
290
  pipe = pipe.to(device)
291
+ else:
292
+ # If accelerate is present, offload module-wise to save RAM
293
+ try:
294
+ pipe.enable_sequential_cpu_offload()
295
+ except Exception:
296
+ pass
297
+
298
+ self.t2i_pipe = pipe
299
+ self.t2i_model_id = model_id
300
+ return pipe, model_id, f"AI model loaded: {model_id}"
301
+
302
+ def ai_generate_background(self, prompt: str, seed: int, width: int, height: int) -> Tuple[Optional[Image.Image], Optional[str], str]:
303
+ """
304
+ Generate a background and save to /tmp. Returns (preview_img, path, status).
305
+ """
306
+ pipe, model_id, msg = self._ensure_t2i()
307
+ if pipe is None:
308
+ logger.warning(msg)
309
+ return None, None, msg
310
+
311
+ # Ensure sane, divisible-by-8 sizes
312
+ w = _div8(int(width)) if width else PREVIEW_W
313
+ h = _div8(int(height)) if height else PREVIEW_H
314
+ w = max(256, min(w, 1536))
315
+ h = max(256, min(h, 1536))
316
+
317
+ # Reasonable defaults for office-like backgrounds
318
+ prompt = (prompt or "professional modern office background, neutral colors, soft depth of field, clean, minimal, photorealistic")
319
+ negative = "text, watermark, logo, people, person, artifact, noisy, blurry"
320
+
321
+ # Seed & inference
322
+ try:
323
+ import torch
324
+ g = None
325
+ device = "cuda" if getattr(torch, "cuda", None) and torch.cuda.is_available() else "cpu"
326
+ try:
327
+ g = torch.Generator(device=device).manual_seed(int(seed)) if seed is not None else None
328
+ except Exception:
329
+ g = None
330
+
331
+ # steps: turbo likes very low steps; classic SD needs more
332
+ steps = 4 if ("turbo" in (model_id or "").lower()) else 25
333
+ guidance = 1.0 if ("turbo" in (model_id or "").lower()) else 7.0
334
+
335
+ with torch.inference_mode():
336
+ if device == "cuda":
337
+ # autocast for fp16
338
+ with torch.autocast("cuda"):
339
+ out = pipe(
340
+ prompt=prompt,
341
+ negative_prompt=negative,
342
+ height=h,
343
+ width=w,
344
+ guidance_scale=guidance,
345
+ num_inference_steps=steps,
346
+ generator=g
347
+ )
348
+ else:
349
+ out = pipe(
350
+ prompt=prompt,
351
+ negative_prompt=negative,
352
+ height=h,
353
+ width=w,
354
+ guidance_scale=guidance,
355
+ num_inference_steps=steps,
356
+ generator=g
357
+ )
358
+ img = out.images[0]
359
+
360
  tmp_path = f"/tmp/ai_bg_{int(time.time())}.png"
361
  img.save(tmp_path)
362
+ # Return preview-sized display to keep UI snappy
363
+ return img.resize((PREVIEW_W, PREVIEW_H), Image.LANCZOS), tmp_path, f"{msg} • Generated {w}x{h}"
364
  except Exception as e:
365
+ logger.exception("AI generation error")
366
+ return None, None, f"AI generation failed: {e}"
367
 
368
  # ---- PROCESS VIDEO ----
369
  def process_video(
 
381
  if not self.models_loaded:
382
  return None, "Models not loaded yet"
383
 
384
+ if not video:
385
+ return None, "Please upload a video first."
386
+
387
  logger.info("process_video called (video=%s, source=%s, preset=%s, file=%s, grad=%s, ai=%s)",
388
  video, bg_source, preset_key, getattr(custom_bg_file, "name", None) if custom_bg_file else None,
389
  {"type": grad_type, "c1": grad_color1, "c2": grad_color2, "angle": grad_angle},
 
398
  return None, "Invalid or unreadable video file"
399
 
400
  # Build bg_config based on source
 
401
  src = (bg_source or "Preset").lower()
402
  if src == "upload" and custom_bg_file is not None:
403
+ bg_cfg: Dict[str, Any] = {"custom_path": custom_bg_file.name}
404
  elif src == "gradient":
405
  bg_cfg = {
406
  "gradient": {
 
461
 
462
  # PRESET
463
  preset_choices = list(PROFESSIONAL_BACKGROUNDS.keys())
464
+ default_preset = "office" if "office" in preset_choices else (preset_choices[0] if preset_choices else "office")
465
+ preset_key = gr.Dropdown(choices=preset_choices, value=default_preset, label="Preset")
466
+
467
  # UPLOAD
468
  custom_bg = gr.File(label="Custom Background (Image)", file_types=["image"], visible=False)
469
+
470
  # GRADIENT
471
  grad_type = gr.Dropdown(choices=["Linear", "Radial"], value="Linear", label="Gradient Type", visible=False)
472
  grad_color1 = gr.ColorPicker(value="#222222", label="Start Color", visible=False)
 
492
  # ---------- UI wiring ----------
493
 
494
  # background source → show/hide controls
495
+ def on_source_toggle(src):
496
  src = (src or "Preset").lower()
497
  return (
498
+ gr.update(visible=(src == "preset")),
499
+ gr.update(visible=(src == "upload")),
500
+ gr.update(visible=(src == "gradient")),
501
+ gr.update(visible=(src == "gradient")),
502
+ gr.update(visible=(src == "gradient")),
503
+ gr.update(visible=(src == "gradient")),
504
+ gr.update(visible=(src == "ai generate")),
505
+ gr.update(visible=(src == "ai generate")),
506
+ gr.update(visible=(src == "ai generate")),
507
+ gr.update(visible=(src == "ai generate")),
508
+ gr.update(visible=(src == "ai generate")),
509
  )
510
  bg_source.change(
511
+ fn=on_source_toggle,
512
  inputs=[bg_source],
513
  outputs=[preset_key, custom_bg, grad_type, grad_color1, grad_color2, grad_angle, ai_prompt, ai_seed, ai_size, ai_go, ai_status],
514
  )
515
 
516
+ # Clear any previous AI image path when switching source (avoids stale AI background)
517
+ def _clear_ai_state(_):
518
+ return None
519
+ bg_source.change(fn=_clear_ai_state, inputs=[bg_source], outputs=[ai_bg_path_state])
520
+
521
+ # When source changes, also refresh preview based on visible controls
522
+ def on_source_preview(src, pkey, gt, c1, c2, ang):
523
+ src_l = (src or "Preset").lower()
524
+ if src_l == "preset":
525
+ return app.preview_preset(pkey)
526
+ elif src_l == "gradient":
527
+ return app.preview_gradient(gt, c1, c2, ang)
528
+ # For upload/AI we keep whatever the component change handler sets (don’t overwrite)
529
+ return gr.update() # no-op
530
+ bg_source.change(
531
+ fn=on_source_preview,
532
+ inputs=[bg_source, preset_key, grad_type, grad_color1, grad_color2, grad_angle],
533
+ outputs=[bg_preview]
534
+ )
535
 
536
+ # live previews
537
+ preset_key.change(fn=lambda k: app.preview_preset(k), inputs=[preset_key], outputs=[bg_preview])
538
+ custom_bg.change(fn=lambda f: app.preview_upload(f), inputs=[custom_bg], outputs=[bg_preview])
539
  for comp in (grad_type, grad_color1, grad_color2, grad_angle):
540
+ comp.change(
541
+ fn=lambda gt, c1, c2, ang: app.preview_gradient(gt, c1, c2, ang),
542
+ inputs=[grad_type, grad_color1, grad_color2, grad_angle],
543
+ outputs=[bg_preview],
544
+ )
545
 
546
  # AI generate
547
  def ai_generate(prompt, seed, size):
548
  try:
549
+ w, h = map(int, size.split("x"))
550
  except Exception:
551
+ w, h = PREVIEW_W, PREVIEW_H
552
+ img, path, msg = app.ai_generate_background(
553
+ prompt or "professional modern office background, neutral colors, depth of field",
554
+ int(seed), w, h
555
+ )
556
  return img, (path or None), msg
557
  ai_go.click(fn=ai_generate, inputs=[ai_prompt, ai_seed, ai_size], outputs=[bg_preview, ai_bg_path_state, ai_status])
558
 
 
560
  def safe_load():
561
  msg = app.load_models()
562
  logger.info("UI: models loaded")
 
563
  return msg, app.preview_preset(preset_key.value if hasattr(preset_key, "value") else "office")
564
  btn_load.click(fn=safe_load, outputs=[status, bg_preview])
565