File size: 3,415 Bytes
7974f2d
 
 
 
61b63f7
7974f2d
 
 
 
 
 
 
 
 
 
 
61b63f7
7974f2d
 
 
 
 
61b63f7
 
 
 
 
7974f2d
 
 
 
 
 
 
61b63f7
 
 
 
 
7974f2d
61b63f7
 
 
7974f2d
 
 
 
 
 
 
61b63f7
 
 
 
 
 
7974f2d
 
 
 
 
 
 
 
61b63f7
 
 
 
7974f2d
 
 
 
 
 
61b63f7
7974f2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# tools/tensor_utils.py
#
# Copyright (C) August 4, 2025  Carlos Rodrigues dos Santos
#
# Version: 1.0.1
#
# This module provides utility functions for tensor manipulation, specifically for
# image and video processing tasks. The functions here, such as wavelet reconstruction,
# are internalized within the ADUC-SDR framework to ensure stability and reduce
# reliance on specific external library structures.
#
# The wavelet_reconstruction code is adapted from the SeedVR project.

import torch
from torch import Tensor
from torch.nn import functional as F
from typing import Tuple

def wavelet_blur(image: Tensor, radius: int) -> Tensor:
    """
    Smooth each channel of a batched image with a dilated 3x3 kernel.

    Args:
        image (Tensor): Input of shape (B, C, H, W).
        radius (int): Width of the replicate padding added on every spatial
            side; also used as the dilation of the 3x3 kernel.

    Returns:
        Tensor: The blurred tensor. Padding by ``radius`` offsets the reach of
        the dilated kernel, so the spatial size matches the input.

    Raises:
        ValueError: If ``image`` is not 4-dimensional.
    """
    if image.ndim != 4:  # only (B, C, H, W) input is accepted
        raise ValueError(f"wavelet_blur expects a 4D tensor, but got shape {image.shape}")

    channels = image.shape[1]

    # 3x3 smoothing weights (they sum to 1), created in the input's dtype/device
    weights = torch.tensor(
        (
            (0.0625, 0.125, 0.0625),
            (0.125, 0.25, 0.125),
            (0.0625, 0.125, 0.0625),
        ),
        dtype=image.dtype,
        device=image.device,
    )

    # one identical single-input filter per channel -> (C, 1, 3, 3)
    depthwise_kernel = weights.repeat(channels, 1, 1, 1)

    # extend the borders by repeating edge pixels so the output keeps its size
    padded = F.pad(image, (radius,) * 4, mode='replicate')

    # groups=channels keeps the channels independent of each other
    return F.conv2d(padded, depthwise_kernel, groups=channels, dilation=radius)

def wavelet_decomposition(image: Tensor, levels=5) -> Tuple[Tensor, Tensor]:
    """
    Split a tensor into accumulated detail and a residual blurred component.

    ``wavelet_blur`` is applied ``levels`` times with radii 1, 2, 4, ...
    The difference removed by each pass is summed into the high-frequency
    output; the result of the last pass is the low-frequency output.

    Args:
        image (Tensor): Either a 4D (B, C, H, W) tensor or a 5D
            (B, C, F, H, W) tensor. For 5D input the frame axis is folded
            into the batch axis so every frame is processed on its own.
        levels: Number of blur passes to run (default 5).

    Returns:
        Tuple[Tensor, Tensor]: ``(high_freq, low_freq)`` in the same layout
        as the input.
    """
    has_frame_axis = image.ndim == 5  # (B, C, F, H, W)
    if has_frame_axis:
        batch, channels, frames, height, width = image.shape
        # move frames next to batch and merge them: (B*F, C, H, W)
        image = image.permute(0, 2, 1, 3, 4).reshape(
            batch * frames, channels, height, width
        )

    detail = torch.zeros_like(image)
    approx = image
    for level in range(levels):
        # the blur radius doubles at every level
        smoothed = wavelet_blur(approx, 2 ** level)
        detail += approx - smoothed
        approx = smoothed

    if not has_frame_axis:
        return detail, approx

    def _unfold_frames(flat: Tensor) -> Tensor:
        # split B*F back apart and restore the (B, C, F, H, W) axis order
        return flat.view(batch, frames, channels, height, width).permute(0, 2, 1, 3, 4)

    return _unfold_frames(detail), _unfold_frames(approx)

def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
    """
    Combine the high-frequency detail of one tensor with the low-frequency
    component (color/lighting) of another.

    Both inputs go through ``wavelet_decomposition``, so images (4D) and
    videos (5D) are handled alike.

    Args:
        content_feat (Tensor): Tensor supplying the structural details.
        style_feat (Tensor): Tensor supplying the color and lighting style.

    Returns:
        Tensor: Sum of the content's high-frequency part and the style's
        low-frequency part.
    """
    # index 0 of the decomposition is the high-frequency (detail) component
    detail = wavelet_decomposition(content_feat)[0]

    # index 1 of the decomposition is the low-frequency (color/lighting) component
    base = wavelet_decomposition(style_feat)[1]

    return detail + base