Aduc_sdr / tools /tensor_utils.py
euiia's picture
Update tools/tensor_utils.py
61b63f7 verified
raw
history blame
3.42 kB
# tools/tensor_utils.py
#
# Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
#
# Version: 1.0.1
#
# This module provides utility functions for tensor manipulation, specifically for
# image and video processing tasks. The functions here, such as wavelet reconstruction,
# are internalized within the ADUC-SDR framework to ensure stability and reduce
# reliance on specific external library structures.
#
# The wavelet_reconstruction code is adapted from the SeedVR project.
import torch
from torch import Tensor
from torch.nn import functional as F
from typing import Tuple
def wavelet_blur(image: Tensor, radius: int) -> Tensor:
    """
    Apply one dilated 3x3 blur pass to every channel of a 4D tensor.

    The input is replicate-padded by ``radius`` on each spatial side and then
    convolved with a fixed 3x3 kernel dilated by ``radius``, so the output has
    the same shape as the input.

    Args:
        image (Tensor): Input of shape (B, C, H, W).
        radius (int): Padding width and kernel dilation; must be >= 1.

    Returns:
        Tensor: Blurred tensor with the same shape as ``image``.

    Raises:
        ValueError: If ``image`` is not 4D or ``radius`` is smaller than 1.
    """
    if image.ndim != 4:  # Expects (B, C, H, W)
        raise ValueError(f"wavelet_blur expects a 4D tensor, but got shape {image.shape}")
    # a radius below 1 would otherwise reach conv2d as an invalid dilation
    if radius < 1:
        raise ValueError(f"wavelet_blur expects radius >= 1, but got {radius}")

    channels = image.shape[1]

    # convolution kernel (weights sum to 1, so constant regions are preserved)
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    kernel = kernel[None, None]  # (1, 1, 3, 3)
    # repeat the kernel across all input channels for grouped convolution
    kernel = kernel.repeat(channels, 1, 1, 1)  # (C, 1, 3, 3)

    # the dilated kernel spans 2 * radius + 1 pixels, so padding by radius keeps H and W
    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    # apply convolution with groups=C to process each channel independently
    return F.conv2d(image, kernel, groups=channels, dilation=radius)
def wavelet_decomposition(image: Tensor, levels=5) -> Tuple[Tensor, Tensor]:
    """
    Split a tensor into accumulated high-frequency detail and a low-frequency residual.

    A 5D input is unpacked as (B, C, F, H, W) and its frames are folded into the
    batch axis for processing; both outputs are folded back to that layout before
    being returned. Any other input is handed to wavelet_blur as-is.

    Args:
        image (Tensor): Input tensor.
        levels: Number of blur passes; pass ``i`` uses radius ``2 ** i``.

    Returns:
        Tuple[Tensor, Tensor]: ``(high_freq, low_freq)``. ``high_freq`` is the sum
        of what each blur pass removed, and ``low_freq`` is the result of the
        last pass (or the input itself when no pass runs).
    """
    has_frame_axis = image.ndim == 5
    if has_frame_axis:
        batch, channels, frames, height, width = image.shape
        # fold frames into the batch axis: (B, C, F, H, W) -> (B * F, C, H, W)
        image = image.permute(0, 2, 1, 3, 4).reshape(batch * frames, channels, height, width)

    residual = image
    detail = torch.zeros_like(image)
    for radius in (2 ** level for level in range(levels)):
        smoothed = wavelet_blur(residual, radius)
        # accumulate the detail removed by this pass
        detail.add_(residual - smoothed)
        residual = smoothed

    if not has_frame_axis:
        return detail, residual

    def unfold_frames(folded: Tensor) -> Tensor:
        # (B * F, C, H, W) -> (B, C, F, H, W)
        return folded.view(batch, frames, channels, height, width).permute(0, 2, 1, 3, 4)

    return unfold_frames(detail), unfold_frames(residual)
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
    """
    Combine the high-frequency detail of one tensor with the low-frequency
    component (color/lighting) of another.

    Both inputs are run through wavelet_decomposition with its default number
    of levels, so images (4D) and videos (5D) are both accepted.

    Args:
        content_feat (Tensor): The tensor supplying the structural details.
        style_feat (Tensor): The tensor supplying the color and lighting style.

    Returns:
        Tensor: Sum of the content's high-frequency part and the style's
        low-frequency part.
    """
    # index 0 of the decomposition is the high-frequency part of the content
    detail = wavelet_decomposition(content_feat)[0]
    # index 1 of the decomposition is the low-frequency part of the style
    base = wavelet_decomposition(style_feat)[1]
    return detail + base