Spaces:
Paused
Paused
Update models/pipeline_controlvideo.py
Browse files- models/pipeline_controlvideo.py +37 -37
models/pipeline_controlvideo.py
CHANGED
|
@@ -670,43 +670,43 @@ class ControlVideoPipeline(DiffusionPipeline):
|
|
| 670 |
return key_frame_indices, inter_frame_list
|
| 671 |
"""
|
| 672 |
def get_slide_window_indices(self, video_length, window_size):
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
@torch.no_grad()
|
| 711 |
def __call__(
|
| 712 |
self,
|
|
|
|
| 670 |
return key_frame_indices, inter_frame_list
|
| 671 |
"""
|
| 672 |
def get_slide_window_indices(self, video_length, window_size):
|
| 673 |
+
assert window_size >= 3
|
| 674 |
+
|
| 675 |
+
# Define the chunk size for processing
|
| 676 |
+
chunk_size = 4
|
| 677 |
+
|
| 678 |
+
# Calculate the number of chunks
|
| 679 |
+
num_chunks = (video_length - 1) // chunk_size + 1
|
| 680 |
+
|
| 681 |
+
# Initialize the lists to store the results
|
| 682 |
+
key_frame_indices = []
|
| 683 |
+
inter_frame_list = []
|
| 684 |
+
|
| 685 |
+
for chunk_index in range(num_chunks):
|
| 686 |
+
# Calculate the start and end indices for the current chunk
|
| 687 |
+
start_index = chunk_index * chunk_size
|
| 688 |
+
end_index = min((chunk_index + 1) * chunk_size, video_length)
|
| 689 |
+
|
| 690 |
+
# Generate key frame indices for the current chunk
|
| 691 |
+
chunk_key_frame_indices = np.arange(start_index, end_index, window_size - 1).tolist()
|
| 692 |
+
|
| 693 |
+
# Append the last index if it's not already included
|
| 694 |
+
if chunk_key_frame_indices[-1] != (end_index - 1):
|
| 695 |
+
chunk_key_frame_indices.append(end_index - 1)
|
| 696 |
+
|
| 697 |
+
# Append the key frame indices of the current chunk to the overall list
|
| 698 |
+
key_frame_indices.extend(chunk_key_frame_indices)
|
| 699 |
+
|
| 700 |
+
# Generate slices for the current chunk
|
| 701 |
+
chunk_slices = np.split(np.arange(start_index, end_index), chunk_key_frame_indices)
|
| 702 |
+
|
| 703 |
+
# Process each slice in the current chunk
|
| 704 |
+
for s in chunk_slices:
|
| 705 |
+
if len(s) < 2:
|
| 706 |
+
continue
|
| 707 |
+
inter_frame_list.append(s[1:].tolist())
|
| 708 |
+
|
| 709 |
+
return key_frame_indices, inter_frame_list
|
| 710 |
@torch.no_grad()
|
| 711 |
def __call__(
|
| 712 |
self,
|