wuhp commited on
Commit
2716e64
·
verified ·
1 Parent(s): 2933d9b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -19
app.py CHANGED
@@ -459,11 +459,11 @@ def _install_supervisely_logger_shim():
459
  """))
460
  return str(root)
461
 
462
- # ---- NEW: robust sitecustomize shim with lazy import hook --------------------
463
  def _install_workspace_shim_v3(dest_dir: str, module_default: str = "rtdetrv2_pytorch.src"):
464
  """
465
  sitecustomize shim that:
466
- - patches workspace.create for positional/keyword cfg,
467
  - ensures cfg is a dict,
468
  - injects cfg['_pymodule'] as a *module object*,
469
  even if the target module is imported after sitecustomize runs.
@@ -494,21 +494,37 @@ def _patch_ws(ws_mod):
494
  if getattr(ws_mod, "__rolo_patched__", False):
495
  return
496
  _orig_create = ws_mod.create
 
 
497
  def create(name, *args, **kwargs):
498
- if args:
499
- args = list(args)
500
- cfg = args[0]
501
- else:
502
- cfg = kwargs.get("cfg", None)
503
- if not isinstance(cfg, dict):
504
- cfg = {} if cfg is None else dict(cfg)
505
  _ensure_pymodule_object(cfg)
506
- if args:
507
- args[0] = cfg
508
- args = tuple(args)
509
- else:
510
- kwargs["cfg"] = cfg
511
- return _orig_create(name, *args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
512
  ws_mod.create = create
513
  ws_mod.__rolo_patched__ = True
514
 
@@ -913,8 +929,6 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
913
  if not base_cfg:
914
  raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/M/M*/L/X).")
915
 
916
- # No longer patch files on disk; we use a runtime shim instead.
917
-
918
  data_yaml = os.path.join(dataset_path, "data.yaml")
919
  with open(data_yaml, "r", encoding="utf-8") as f:
920
  dy = yaml.safe_load(f)
@@ -1133,7 +1147,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
1133
  gr.Markdown("Pick RT-DETRv2 model, set hyper-params, press Start.")
1134
  with gr.Row():
1135
  with gr.Column(scale=1):
1136
- model_dd = gr.Dropdown(choices=[k for k,_ in MODEL_CHOICES], value=DEFAULT_MODEL_KEY,
1137
  label="Model (RT-DETRv2)")
1138
  run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
1139
  epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
@@ -1182,4 +1196,4 @@ if __name__ == "__main__":
1182
  logging.info(f"Startup check — training script at: {ts}")
1183
  except Exception as e:
1184
  logging.warning(f"Startup training-script check failed: {e}")
1185
- app.launch(debug=True)
 
459
  """))
460
  return str(root)
461
 
462
+ # ---- [UPDATED] robust sitecustomize shim with lazy import hook --------------------
463
  def _install_workspace_shim_v3(dest_dir: str, module_default: str = "rtdetrv2_pytorch.src"):
464
  """
465
  sitecustomize shim that:
466
+ - patches workspace.create to handle dict-based component definitions,
467
  - ensures cfg is a dict,
468
  - injects cfg['_pymodule'] as a *module object*,
469
  even if the target module is imported after sitecustomize runs.
 
494
  if getattr(ws_mod, "__rolo_patched__", False):
495
  return
496
  _orig_create = ws_mod.create
497
+
498
+ # NEW, FIXED create function
499
  def create(name, *args, **kwargs):
500
+ # Unify all config sources into one dictionary. The main config is often the second arg.
501
+ cfg = {}
502
+ if args and isinstance(args[0], dict):
503
+ cfg.update(args[0])
504
+ if 'cfg' in kwargs and isinstance(kwargs['cfg'], dict):
505
+ cfg.update(kwargs['cfg'])
506
+
507
  _ensure_pymodule_object(cfg)
508
+
509
+ # The core of the fix: handle when the component itself is passed as a dict.
510
+ # This is what happens when the library tries to create the model.
511
+ if isinstance(name, dict):
512
+ component_params = name.copy()
513
+ type_name = component_params.pop('type', None)
514
+ if type_name is None:
515
+ # If no 'type' key, we can't proceed. Fall back to original to get the original error.
516
+ return _orig_create(name, *args, **kwargs)
517
+
518
+ # Merge the component's own parameters (like num_classes) into the main config.
519
+ cfg.update(component_params)
520
+
521
+ # Now, call the original `create` function the way it expects:
522
+ # with the component name as a string, and the full config.
523
+ return _orig_create(type_name, cfg=cfg)
524
+
525
+ # If 'name' was already a string (the normal case for solvers, etc.), proceed as expected.
526
+ return _orig_create(name, cfg=cfg)
527
+
528
  ws_mod.create = create
529
  ws_mod.__rolo_patched__ = True
530
 
 
929
  if not base_cfg:
930
  raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/M/M*/L/X).")
931
 
 
 
932
  data_yaml = os.path.join(dataset_path, "data.yaml")
933
  with open(data_yaml, "r", encoding="utf-8") as f:
934
  dy = yaml.safe_load(f)
 
1147
  gr.Markdown("Pick RT-DETRv2 model, set hyper-params, press Start.")
1148
  with gr.Row():
1149
  with gr.Column(scale=1):
1150
+ model_dd = gr.Dropdown(choices=[(label, k) for k, label in MODEL_CHOICES], value=DEFAULT_MODEL_KEY,
1151
  label="Model (RT-DETRv2)")
1152
  run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
1153
  epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
 
1196
  logging.info(f"Startup check — training script at: {ts}")
1197
  except Exception as e:
1198
  logging.warning(f"Startup training-script check failed: {e}")
1199
+ app.launch(debug=True)