Update tools/tensor_utils.py
Browse files- tools/tensor_utils.py +26 -7
tools/tensor_utils.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
#
|
| 3 |
# Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
|
| 4 |
#
|
| 5 |
-
# Version: 1.0.
|
| 6 |
#
|
| 7 |
# This module provides utility functions for tensor manipulation, specifically for
|
| 8 |
# image and video processing tasks. The functions here, such as wavelet reconstruction,
|
|
@@ -14,11 +14,17 @@
|
|
| 14 |
import torch
|
| 15 |
from torch import Tensor
|
| 16 |
from torch.nn import functional as F
|
|
|
|
| 17 |
|
| 18 |
def wavelet_blur(image: Tensor, radius: int) -> Tensor:
|
| 19 |
"""
|
| 20 |
Apply wavelet blur to the input tensor.
|
| 21 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
# convolution kernel
|
| 23 |
kernel_vals = [
|
| 24 |
[0.0625, 0.125, 0.0625],
|
|
@@ -26,13 +32,15 @@ def wavelet_blur(image: Tensor, radius: int) -> Tensor:
|
|
| 26 |
[0.0625, 0.125, 0.0625],
|
| 27 |
]
|
| 28 |
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
# repeat the kernel across all input channels
|
| 32 |
-
kernel = kernel.repeat(
|
|
|
|
| 33 |
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
|
| 34 |
-
|
| 35 |
-
|
|
|
|
| 36 |
return output
|
| 37 |
|
| 38 |
def wavelet_decomposition(image: Tensor, levels=5) -> Tuple[Tensor, Tensor]:
|
|
@@ -40,6 +48,12 @@ def wavelet_decomposition(image: Tensor, levels=5) -> Tuple[Tensor, Tensor]:
|
|
| 40 |
Apply wavelet decomposition to the input tensor.
|
| 41 |
This function returns both the high frequency and low frequency components.
|
| 42 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
high_freq = torch.zeros_like(image)
|
| 44 |
low_freq = image
|
| 45 |
for i in range(levels):
|
|
@@ -48,12 +62,17 @@ def wavelet_decomposition(image: Tensor, levels=5) -> Tuple[Tensor, Tensor]:
|
|
| 48 |
high_freq += (low_freq - blurred)
|
| 49 |
low_freq = blurred
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
return high_freq, low_freq
|
| 52 |
|
| 53 |
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
| 54 |
"""
|
| 55 |
Applies wavelet decomposition to transfer the color/style (low-frequency components)
|
| 56 |
from a style feature to the details (high-frequency components) of a content feature.
|
|
|
|
| 57 |
|
| 58 |
Args:
|
| 59 |
content_feat (Tensor): The tensor containing the structural details.
|
|
|
|
| 2 |
#
|
| 3 |
# Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
|
| 4 |
#
|
| 5 |
+
# Version: 1.0.1
|
| 6 |
#
|
| 7 |
# This module provides utility functions for tensor manipulation, specifically for
|
| 8 |
# image and video processing tasks. The functions here, such as wavelet reconstruction,
|
|
|
|
| 14 |
import torch
|
| 15 |
from torch import Tensor
|
| 16 |
from torch.nn import functional as F
|
| 17 |
+
from typing import Tuple
|
| 18 |
|
| 19 |
def wavelet_blur(image: Tensor, radius: int) -> Tensor:
|
| 20 |
"""
|
| 21 |
Apply wavelet blur to the input tensor.
|
| 22 |
"""
|
| 23 |
+
if image.ndim != 4: # Expects (B, C, H, W)
|
| 24 |
+
raise ValueError(f"wavelet_blur expects a 4D tensor, but got shape {image.shape}")
|
| 25 |
+
|
| 26 |
+
b, c, h, w = image.shape
|
| 27 |
+
|
| 28 |
# convolution kernel
|
| 29 |
kernel_vals = [
|
| 30 |
[0.0625, 0.125, 0.0625],
|
|
|
|
| 32 |
[0.0625, 0.125, 0.0625],
|
| 33 |
]
|
| 34 |
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
|
| 35 |
+
kernel = kernel[None, None] # (1, 1, 3, 3)
|
| 36 |
+
|
| 37 |
+
# repeat the kernel across all input channels for grouped convolution
|
| 38 |
+
kernel = kernel.repeat(c, 1, 1, 1) # (C, 1, 3, 3)
|
| 39 |
+
|
| 40 |
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
|
| 41 |
+
|
| 42 |
+
# apply convolution with groups=c to process each channel independently
|
| 43 |
+
output = F.conv2d(image, kernel, groups=c, dilation=radius)
|
| 44 |
return output
|
| 45 |
|
| 46 |
def wavelet_decomposition(image: Tensor, levels=5) -> Tuple[Tensor, Tensor]:
|
|
|
|
| 48 |
Apply wavelet decomposition to the input tensor.
|
| 49 |
This function returns both the high frequency and low frequency components.
|
| 50 |
"""
|
| 51 |
+
# Fold 5D video input (B, C, F, H, W) into a 4D batch (B * F, C, H, W); other inputs pass through unchanged
|
| 52 |
+
is_video_frame = image.ndim == 5 # (B, C, F, H, W)
|
| 53 |
+
if is_video_frame:
|
| 54 |
+
b, c, f, h, w = image.shape
|
| 55 |
+
image = image.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)
|
| 56 |
+
|
| 57 |
high_freq = torch.zeros_like(image)
|
| 58 |
low_freq = image
|
| 59 |
for i in range(levels):
|
|
|
|
| 62 |
high_freq += (low_freq - blurred)
|
| 63 |
low_freq = blurred
|
| 64 |
|
| 65 |
+
if is_video_frame:
|
| 66 |
+
high_freq = high_freq.view(b, f, c, h, w).permute(0, 2, 1, 3, 4)
|
| 67 |
+
low_freq = low_freq.view(b, f, c, h, w).permute(0, 2, 1, 3, 4)
|
| 68 |
+
|
| 69 |
return high_freq, low_freq
|
| 70 |
|
| 71 |
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
| 72 |
"""
|
| 73 |
Applies wavelet decomposition to transfer the color/style (low-frequency components)
|
| 74 |
from a style feature to the details (high-frequency components) of a content feature.
|
| 75 |
+
This works for both images (4D) and videos (5D).
|
| 76 |
|
| 77 |
Args:
|
| 78 |
content_feat (Tensor): The tensor containing the structural details.
|