Commit dfc94bc · committed by Julian Bilcke
Parent(s): ac45732
upgrade finetrainers + freeze datasets
finetrainers/data/dataset.py
CHANGED
@@ -970,59 +970,9 @@ def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
     image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
     return image
 
-
-
-
-
-
-
-        video = video.get_batch(list(range(len(video))))
-        video = video.permute(0, 3, 1, 2).contiguous() / 127.5 - 1.0
-        return video
-
-    # For torchvision VideoReader
-    elif 'torchvision.io.video_reader' in str(type(video)):
-        # Use the correct iteration pattern for torchvision.io.VideoReader
-        frames = []
-        try:
-            # First seek to the beginning
-            video.seek(0)
-
-            # Then collect frames by iterating
-            for _ in range(30):  # Try to get a reasonable number of frames
-                try:
-                    frame_dict = next(video)
-                    frame = frame_dict["data"]  # Extract the tensor data from the dict
-                    frames.append(frame)
-                except StopIteration:
-                    break
-        except Exception as e:
-            print(f"Error iterating VideoReader: {e}")
-
-        if frames:
-            # In torchvision.io.VideoReader, frames are already in [C, H, W] format
-            # We need to stack and convert to [B, C, H, W]
-            stacked_frames = torch.stack(frames)
-            # Normalize to [-1, 1]
-            stacked_frames = stacked_frames.float() / 127.5 - 1.0
-            return stacked_frames
-
-        # If we couldn't get frames, create a dummy tensor
-        print("Failed to get frames, creating dummy tensor")
-        return torch.zeros(16, 3, 512, 768).float()
-
-    # For list of PIL images
-    elif isinstance(video, list) and len(video) > 0 and hasattr(video[0], 'convert'):
-        frames = []
-        for img in video:
-            img_tensor = torch.from_numpy(np.array(img.convert("RGB"))).float()
-            frames.append(img_tensor)
-
-        video = torch.stack(frames)
-        video = video.permute(0, 3, 1, 2).contiguous() / 127.5 - 1.0
-        return video
-
-    # Unknown type
-    else:
-        print(f"Unknown video type: {type(video)}")
-        return torch.zeros(16, 3, 512, 768).float()
+
+def _preprocess_video(video: decord.VideoReader) -> torch.Tensor:
+    video = video.get_batch(list(range(len(video))))
+    video = video.permute(0, 3, 1, 2).contiguous()
+    video = video.float() / 127.5 - 1.0
+    return video
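This change collapses `_preprocess_video` back to a decord-only implementation: the torchvision `VideoReader` iteration, the PIL-list branch, and the dummy-tensor fallbacks are gone, and frames are decoded in a single `get_batch` call and scaled to [-1, 1]. Below is a minimal sketch (not part of the commit) of how that path is exercised, assuming decord is installed with its torch bridge enabled and that `sample.mp4` is a placeholder local file:

```python
import decord
import torch

decord.bridge.set_bridge("torch")  # make get_batch return torch tensors

def _preprocess_video(video: decord.VideoReader) -> torch.Tensor:
    # Decode every frame: [T, H, W, C] uint8 -> [T, C, H, W] float in [-1, 1]
    video = video.get_batch(list(range(len(video))))
    video = video.permute(0, 3, 1, 2).contiguous()
    video = video.float() / 127.5 - 1.0
    return video

reader = decord.VideoReader("sample.mp4")  # hypothetical input file
frames = _preprocess_video(reader)
print(frames.shape, frames.dtype, frames.min().item(), frames.max().item())
```

With the torch bridge set, `get_batch` returns a uint8 tensor of shape [T, H, W, C]; the permute reorders it to [T, C, H, W] before normalization.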
finetrainers/trainer/sft_trainer/trainer.py
CHANGED
@@ -694,13 +694,14 @@ class SFTTrainer:
         # 3. Cleanup & log artifacts
         parallel_backend.wait_for_everyone()
 
+        memory_statistics = utils.get_memory_statistics()
+        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
+
         # Remove all hooks that might have been added during pipeline initialization to the models
+        module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "vae"]
         pipeline.remove_all_hooks()
         del pipeline
-
-        utils.free_memory()
-        memory_statistics = utils.get_memory_statistics()
-        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
+        self._delete_components(module_names)
         torch.cuda.reset_peak_memory_stats(parallel_backend.device)
 
         # Gather artifacts from all processes. We also need to flatten them since each process returns a list of artifacts.

@@ -788,7 +789,7 @@
 
     def _init_trackers(self) -> None:
         # TODO(aryan): handle multiple trackers
-        trackers = [
+        trackers = [self.args.report_to]
         experiment_name = self.args.tracker_name or "finetrainers-experiment"
         self.state.parallel_backend.initialize_trackers(
             trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir

@@ -836,7 +837,6 @@
         utils.synchronize_device()
 
     def _init_pipeline(self, final_validation: bool = False) -> DiffusionPipeline:
-        parallel_backend = self.state.parallel_backend
         module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae"]
 
         if not final_validation:

@@ -871,7 +871,6 @@
                 enable_tiling=self.args.enable_tiling,
                 enable_model_cpu_offload=self.args.enable_model_cpu_offload,
                 training=False,
-                device=parallel_backend.device,
             )
 
         # Load the LoRA weights if performing LoRA finetuning

@@ -880,7 +879,8 @@
 
         components = {module_name: getattr(pipeline, module_name, None) for module_name in module_names}
         self._set_components(components)
-        self.
+        if not self.args.enable_model_cpu_offload:
+            self._move_components_to_device(list(components.values()))
         return pipeline
 
     def _prepare_data(

@@ -923,17 +923,12 @@
         else:
             logger.info("Precomputed condition & latent data exhausted. Loading & preprocessing new data.")
 
-
-
-
-
-
-
-            # force=True,
-            # _device=parallel_backend.device,
-            # _is_main_process=parallel_backend.is_main_process,
-            # )
-            # self._delete_components(component_names=["transformer", "unet"])
+            parallel_backend = self.state.parallel_backend
+            if parallel_backend.world_size == 1:
+                self._move_components_to_device([self.transformer], "cpu")
+            utils.free_memory()
+            utils.synchronize_device()
+            torch.cuda.reset_peak_memory_stats(parallel_backend.device)
 
             if self.args.precomputation_once:
                 consume_fn = preprocessor.consume_once

@@ -974,8 +969,8 @@
         self._delete_components(component_names)
         del latent_components, component_names, component_modules
 
-
-
+        if parallel_backend.world_size == 1:
+            self._move_components_to_device([self.transformer])
 
         return condition_iterator, latent_iterator
 
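The trainer changes log memory statistics before hook removal, delete the text encoders and VAE once validation ends, restore tracker initialization via `self.args.report_to`, and, when running on a single process, park the transformer on CPU while conditions and latents are precomputed, moving it back to the device afterwards. A rough sketch of that offload/restore pattern in plain PyTorch (illustrative only; the helper name and structure are mine, not finetrainers'):

```python
import torch
import torch.nn as nn

def offload_during(module: nn.Module, device: torch.device, work):
    """Run `work()` with `module` parked on CPU, then restore it to `device`."""
    if device.type == "cuda":
        module.to("cpu")                            # free VRAM for precomputation
        torch.cuda.empty_cache()                    # rough analogue of utils.free_memory()
        torch.cuda.synchronize(device)              # analogue of utils.synchronize_device()
        torch.cuda.reset_peak_memory_stats(device)
    result = work()                                 # e.g. encode prompts / compute latents
    module.to(device)                               # bring the model back for training
    return result

# Hypothetical usage with a stand-in "transformer"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transformer = nn.Linear(8, 8).to(device)
offload_during(transformer, device, lambda: torch.randn(2, 8))
```

The point is simply that the transformer does not need VRAM while the text encoders and VAE are producing embeddings and latents, which is what makes precomputation fit on a single GPU.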
requirements.txt
CHANGED
@@ -7,6 +7,10 @@ torch==2.5.1
 torchvision==0.20.1
 torchao==0.6.1
 
+# datasets 3.4.0 replaces decord by torchvision
+# let's freeze it for now
+datasets==3.3.2
+
 huggingface_hub
 hf_transfer>=0.1.8
 diffusers @ git+https://github.com/huggingface/diffusers.git@main
requirements_without_flash_attention.txt
CHANGED
@@ -8,6 +8,10 @@ torch==2.5.1
 torchvision==0.20.1
 torchao==0.6.1
 
+# datasets 3.4.0 replaces decord by torchvision
+# let's freeze it for now
+datasets==3.3.2
+
 huggingface_hub
 hf_transfer>=0.1.8
 diffusers @ git+https://github.com/huggingface/diffusers.git@main
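Both requirements files pin `datasets==3.3.2` because, per the comment in the diff, datasets 3.4.0 switches video decoding from decord to torchvision, which would hand `_preprocess_video` a reader type it no longer handles. An optional guard (my addition, not part of the commit) that fails fast if the installed version drifts from the pin:

```python
# Optional sanity check: abort early if the datasets pin is not respected,
# since 3.4.0+ decodes videos with torchvision instead of decord.
from importlib.metadata import version

installed = version("datasets")
if installed != "3.3.2":
    raise RuntimeError(
        f"datasets=={installed} is installed, but this project pins 3.3.2 "
        "because newer releases decode videos with torchvision instead of decord."
    )
```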