Spaces:
Runtime error
Runtime error
Fixed some bugs with finetrainers CLI params
Browse files
- vms/config.py +3 -0
- vms/services/trainer.py +25 -22
- vms/tabs/train_tab.py +13 -5
- vms/ui/video_trainer_ui.py +1 -1
vms/config.py
CHANGED
|
@@ -485,6 +485,9 @@ class TrainingConfig:
|
|
| 485 |
if self.precompute_conditions:
|
| 486 |
args.append("--precompute_conditions")
|
| 487 |
|
|
|
|
|
|
|
|
|
|
| 488 |
# Diffusion arguments
|
| 489 |
if self.flow_resolution_shifting:
|
| 490 |
args.append("--flow_resolution_shifting")
|
|
|
|
| 485 |
if self.precompute_conditions:
|
| 486 |
args.append("--precompute_conditions")
|
| 487 |
|
| 488 |
+
if hasattr(self, 'precomputation_items') and self.precomputation_items:
|
| 489 |
+
args.extend(["--precomputation_items", str(self.precomputation_items)])
|
| 490 |
+
|
| 491 |
# Diffusion arguments
|
| 492 |
if self.flow_resolution_shifting:
|
| 493 |
args.append("--flow_resolution_shifting")
|
vms/services/trainer.py
CHANGED
|
@@ -52,7 +52,10 @@ from ..utils import (
|
|
| 52 |
logger = logging.getLogger(__name__)
|
| 53 |
|
| 54 |
class TrainingService:
|
| 55 |
-
def __init__(self):
|
|
|
|
|
|
|
|
|
|
| 56 |
# State and log files
|
| 57 |
self.session_file = OUTPUT_PATH / "session.json"
|
| 58 |
self.status_file = OUTPUT_PATH / "status.json"
|
|
@@ -565,8 +568,8 @@ class TrainingService:
|
|
| 565 |
logger.info(f"{log_prefix} training with model_type={model_type}, training_type={training_type}")
|
| 566 |
|
| 567 |
# Update progress if available
|
| 568 |
-
if progress:
|
| 569 |
-
|
| 570 |
|
| 571 |
try:
|
| 572 |
# Get absolute paths - FIXED to look in project root instead of within vms directory
|
|
@@ -598,8 +601,8 @@ class TrainingService:
|
|
| 598 |
logger.info("Training data path: %s", TRAINING_PATH)
|
| 599 |
|
| 600 |
# Update progress
|
| 601 |
-
if progress:
|
| 602 |
-
|
| 603 |
|
| 604 |
videos_file, prompts_file = prepare_finetrainers_dataset()
|
| 605 |
if videos_file is None or prompts_file is None:
|
|
@@ -616,8 +619,8 @@ class TrainingService:
|
|
| 616 |
return error_msg, "No training data available"
|
| 617 |
|
| 618 |
# Update progress
|
| 619 |
-
if progress:
|
| 620 |
-
|
| 621 |
|
| 622 |
# Get preset configuration
|
| 623 |
preset = TRAINING_PRESETS[preset_name]
|
|
@@ -627,13 +630,14 @@ class TrainingService:
|
|
| 627 |
|
| 628 |
# Get the custom prompt prefix from the tabs
|
| 629 |
custom_prompt_prefix = None
|
| 630 |
-
if hasattr(self
|
| 631 |
-
if hasattr(self.app
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
|
|
|
| 637 |
|
| 638 |
# Create a proper dataset configuration JSON file
|
| 639 |
dataset_config_file = OUTPUT_PATH / "dataset_config.json"
|
|
@@ -725,10 +729,7 @@ class TrainingService:
|
|
| 725 |
config.flow_weighting_scheme = flow_weighting_scheme
|
| 726 |
|
| 727 |
config.lr_warmup_steps = int(lr_warmup_steps)
|
| 728 |
-
|
| 729 |
-
"--precomputation_items", str(precomputation_items)
|
| 730 |
-
])
|
| 731 |
-
|
| 732 |
# Update the NUM_GPUS variable and CUDA_VISIBLE_DEVICES
|
| 733 |
num_gpus = min(num_gpus, get_available_gpu_count())
|
| 734 |
if num_gpus <= 0:
|
|
@@ -757,6 +758,8 @@ class TrainingService:
|
|
| 757 |
config.enable_tiling = True
|
| 758 |
config.caption_dropout_p = DEFAULT_CAPTION_DROPOUT_P
|
| 759 |
|
|
|
|
|
|
|
| 760 |
validation_error = self.validate_training_config(config, model_type)
|
| 761 |
if validation_error:
|
| 762 |
error_msg = f"Configuration validation failed: {validation_error}"
|
|
@@ -843,8 +846,8 @@ class TrainingService:
|
|
| 843 |
env["FINETRAINERS_LOG_LEVEL"] = "DEBUG" # Added for better debugging
|
| 844 |
env["CUDA_VISIBLE_DEVICES"] = visible_devices
|
| 845 |
|
| 846 |
-
if progress:
|
| 847 |
-
|
| 848 |
|
| 849 |
# Start the training process
|
| 850 |
process = subprocess.Popen(
|
|
@@ -901,8 +904,8 @@ class TrainingService:
|
|
| 901 |
logger.info(success_msg)
|
| 902 |
|
| 903 |
# Final progress update - now we'll track it through the log monitor
|
| 904 |
-
if progress:
|
| 905 |
-
|
| 906 |
|
| 907 |
return success_msg, self.get_logs()
|
| 908 |
|
|
|
|
| 52 |
logger = logging.getLogger(__name__)
|
| 53 |
|
| 54 |
class TrainingService:
|
| 55 |
+
def __init__(self, app=None):
|
| 56 |
+
# Store reference to app
|
| 57 |
+
self.app = app
|
| 58 |
+
|
| 59 |
# State and log files
|
| 60 |
self.session_file = OUTPUT_PATH / "session.json"
|
| 61 |
self.status_file = OUTPUT_PATH / "status.json"
|
|
|
|
| 568 |
logger.info(f"{log_prefix} training with model_type={model_type}, training_type={training_type}")
|
| 569 |
|
| 570 |
# Update progress if available
|
| 571 |
+
#if progress:
|
| 572 |
+
# progress(0.15, desc="Setting up training configuration")
|
| 573 |
|
| 574 |
try:
|
| 575 |
# Get absolute paths - FIXED to look in project root instead of within vms directory
|
|
|
|
| 601 |
logger.info("Training data path: %s", TRAINING_PATH)
|
| 602 |
|
| 603 |
# Update progress
|
| 604 |
+
#if progress:
|
| 605 |
+
# progress(0.2, desc="Preparing training dataset")
|
| 606 |
|
| 607 |
videos_file, prompts_file = prepare_finetrainers_dataset()
|
| 608 |
if videos_file is None or prompts_file is None:
|
|
|
|
| 619 |
return error_msg, "No training data available"
|
| 620 |
|
| 621 |
# Update progress
|
| 622 |
+
#if progress:
|
| 623 |
+
# progress(0.25, desc="Creating dataset configuration")
|
| 624 |
|
| 625 |
# Get preset configuration
|
| 626 |
preset = TRAINING_PRESETS[preset_name]
|
|
|
|
| 630 |
|
| 631 |
# Get the custom prompt prefix from the tabs
|
| 632 |
custom_prompt_prefix = None
|
| 633 |
+
if hasattr(self, 'app') and self.app is not None:
|
| 634 |
+
if hasattr(self.app, 'tabs') and 'caption_tab' in self.app.tabs:
|
| 635 |
+
if hasattr(self.app.tabs['caption_tab'], 'components') and 'custom_prompt_prefix' in self.app.tabs['caption_tab'].components:
|
| 636 |
+
# Get the value and clean it
|
| 637 |
+
prefix = self.app.tabs['caption_tab'].components['custom_prompt_prefix'].value
|
| 638 |
+
if prefix:
|
| 639 |
+
# Clean the prefix - remove trailing comma, space or comma+space
|
| 640 |
+
custom_prompt_prefix = prefix.rstrip(', ')
|
| 641 |
|
| 642 |
# Create a proper dataset configuration JSON file
|
| 643 |
dataset_config_file = OUTPUT_PATH / "dataset_config.json"
|
|
|
|
| 729 |
config.flow_weighting_scheme = flow_weighting_scheme
|
| 730 |
|
| 731 |
config.lr_warmup_steps = int(lr_warmup_steps)
|
| 732 |
+
|
|
|
|
|
|
|
|
|
|
| 733 |
# Update the NUM_GPUS variable and CUDA_VISIBLE_DEVICES
|
| 734 |
num_gpus = min(num_gpus, get_available_gpu_count())
|
| 735 |
if num_gpus <= 0:
|
|
|
|
| 758 |
config.enable_tiling = True
|
| 759 |
config.caption_dropout_p = DEFAULT_CAPTION_DROPOUT_P
|
| 760 |
|
| 761 |
+
config.precomputation_items = precomputation_items
|
| 762 |
+
|
| 763 |
validation_error = self.validate_training_config(config, model_type)
|
| 764 |
if validation_error:
|
| 765 |
error_msg = f"Configuration validation failed: {validation_error}"
|
|
|
|
| 846 |
env["FINETRAINERS_LOG_LEVEL"] = "DEBUG" # Added for better debugging
|
| 847 |
env["CUDA_VISIBLE_DEVICES"] = visible_devices
|
| 848 |
|
| 849 |
+
#if progress:
|
| 850 |
+
# progress(0.9, desc="Launching training process")
|
| 851 |
|
| 852 |
# Start the training process
|
| 853 |
process = subprocess.Popen(
|
|
|
|
| 904 |
logger.info(success_msg)
|
| 905 |
|
| 906 |
# Final progress update - now we'll track it through the log monitor
|
| 907 |
+
#if progress:
|
| 908 |
+
# progress(1.0, desc="Training started successfully")
|
| 909 |
|
| 910 |
return success_msg, self.get_logs()
|
| 911 |
|
vms/tabs/train_tab.py
CHANGED
|
@@ -384,7 +384,9 @@ class TrainTab(BaseTab):
|
|
| 384 |
outputs=[self.components["status_box"]]
|
| 385 |
)
|
| 386 |
|
| 387 |
-
def handle_training_start(
|
|
|
|
|
|
|
| 388 |
"""Handle training start with proper log parser reset and checkpoint detection"""
|
| 389 |
# Safely reset log parser if it exists
|
| 390 |
if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
|
|
@@ -395,7 +397,7 @@ class TrainTab(BaseTab):
|
|
| 395 |
self.app.log_parser = TrainingLogParser()
|
| 396 |
|
| 397 |
# Initialize progress
|
| 398 |
-
progress(0, desc="Initializing training")
|
| 399 |
|
| 400 |
# Check for latest checkpoint
|
| 401 |
checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
|
|
@@ -406,9 +408,10 @@ class TrainTab(BaseTab):
|
|
| 406 |
latest_checkpoint = max(checkpoints, key=os.path.getmtime)
|
| 407 |
resume_from = str(latest_checkpoint)
|
| 408 |
logger.info(f"Found checkpoint at {resume_from}, will resume training")
|
| 409 |
-
progress(0.05, desc=f"Resuming from checkpoint {Path(resume_from).name}")
|
| 410 |
else:
|
| 411 |
-
progress(0.05, desc="Starting new training run")
|
|
|
|
| 412 |
|
| 413 |
# Convert model_type display name to internal name
|
| 414 |
model_internal_type = MODEL_TYPES.get(model_type)
|
|
@@ -424,8 +427,13 @@ class TrainTab(BaseTab):
|
|
| 424 |
logger.error(f"Invalid training type: {training_type}")
|
| 425 |
return f"Error: Invalid training type '{training_type}'", "Training type not recognized"
|
| 426 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
# Progress update
|
| 428 |
-
progress(0.1, desc="Preparing dataset")
|
| 429 |
|
| 430 |
# Start training (it will automatically use the checkpoint if provided)
|
| 431 |
try:
|
|
|
|
| 384 |
outputs=[self.components["status_box"]]
|
| 385 |
)
|
| 386 |
|
| 387 |
+
def handle_training_start(
|
| 388 |
+
self, preset, model_type, training_type, lora_rank, lora_alpha, train_steps, batch_size, learning_rate, save_iterations, repo_id, progress=gr.Progress()
|
| 389 |
+
):
|
| 390 |
"""Handle training start with proper log parser reset and checkpoint detection"""
|
| 391 |
# Safely reset log parser if it exists
|
| 392 |
if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
|
|
|
|
| 397 |
self.app.log_parser = TrainingLogParser()
|
| 398 |
|
| 399 |
# Initialize progress
|
| 400 |
+
#progress(0, desc="Initializing training")
|
| 401 |
|
| 402 |
# Check for latest checkpoint
|
| 403 |
checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
|
|
|
|
| 408 |
latest_checkpoint = max(checkpoints, key=os.path.getmtime)
|
| 409 |
resume_from = str(latest_checkpoint)
|
| 410 |
logger.info(f"Found checkpoint at {resume_from}, will resume training")
|
| 411 |
+
#progress(0.05, desc=f"Resuming from checkpoint {Path(resume_from).name}")
|
| 412 |
else:
|
| 413 |
+
#progress(0.05, desc="Starting new training run")
|
| 414 |
+
pass
|
| 415 |
|
| 416 |
# Convert model_type display name to internal name
|
| 417 |
model_internal_type = MODEL_TYPES.get(model_type)
|
|
|
|
| 427 |
logger.error(f"Invalid training type: {training_type}")
|
| 428 |
return f"Error: Invalid training type '{training_type}'", "Training type not recognized"
|
| 429 |
|
| 430 |
+
# Get other parameters from UI form
|
| 431 |
+
num_gpus = int(self.components["num_gpus"].value)
|
| 432 |
+
precomputation_items = int(self.components["precomputation_items"].value)
|
| 433 |
+
lr_warmup_steps = int(self.components["lr_warmup_steps"].value)
|
| 434 |
+
|
| 435 |
# Progress update
|
| 436 |
+
#progress(0.1, desc="Preparing dataset")
|
| 437 |
|
| 438 |
# Start training (it will automatically use the checkpoint if provided)
|
| 439 |
try:
|
vms/ui/video_trainer_ui.py
CHANGED
|
@@ -40,7 +40,7 @@ class VideoTrainerUI:
|
|
| 40 |
def __init__(self):
|
| 41 |
"""Initialize services and tabs"""
|
| 42 |
# Initialize core services
|
| 43 |
-
self.trainer = TrainingService()
|
| 44 |
self.splitter = SplittingService()
|
| 45 |
self.importer = ImportService()
|
| 46 |
self.captioner = CaptioningService()
|
|
|
|
| 40 |
def __init__(self):
|
| 41 |
"""Initialize services and tabs"""
|
| 42 |
# Initialize core services
|
| 43 |
+
self.trainer = TrainingService(self)
|
| 44 |
self.splitter = SplittingService()
|
| 45 |
self.importer = ImportService()
|
| 46 |
self.captioner = CaptioningService()
|