euiia committed on
Commit
61b63f7
·
verified ·
1 Parent(s): 539271a

Update tools/tensor_utils.py

Browse files
Files changed (1) hide show
  1. tools/tensor_utils.py +26 -7
tools/tensor_utils.py CHANGED
@@ -2,7 +2,7 @@
2
  #
3
  # Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
4
  #
5
- # Version: 1.0.0
6
  #
7
  # This module provides utility functions for tensor manipulation, specifically for
8
  # image and video processing tasks. The functions here, such as wavelet reconstruction,
@@ -14,11 +14,17 @@
14
  import torch
15
  from torch import Tensor
16
  from torch.nn import functional as F
 
17
 
18
  def wavelet_blur(image: Tensor, radius: int) -> Tensor:
19
  """
20
  Apply wavelet blur to the input tensor.
21
  """
 
 
 
 
 
22
  # convolution kernel
23
  kernel_vals = [
24
  [0.0625, 0.125, 0.0625],
@@ -26,13 +32,15 @@ def wavelet_blur(image: Tensor, radius: int) -> Tensor:
26
  [0.0625, 0.125, 0.0625],
27
  ]
28
  kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
29
- # add channel dimensions to the kernel to make it a 4D tensor
30
- kernel = kernel[None, None]
31
- # repeat the kernel across all input channels
32
- kernel = kernel.repeat(image.shape[1], 1, 1, 1) # Match input channels
 
33
  image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
34
- # apply convolution
35
- output = F.conv2d(image, kernel, groups=image.shape[1], dilation=radius)
 
36
  return output
37
 
38
  def wavelet_decomposition(image: Tensor, levels=5) -> Tuple[Tensor, Tensor]:
@@ -40,6 +48,12 @@ def wavelet_decomposition(image: Tensor, levels=5) -> Tuple[Tensor, Tensor]:
40
  Apply wavelet decomposition to the input tensor.
41
  This function returns both the high frequency and low frequency components.
42
  """
 
 
 
 
 
 
43
  high_freq = torch.zeros_like(image)
44
  low_freq = image
45
  for i in range(levels):
@@ -48,12 +62,17 @@ def wavelet_decomposition(image: Tensor, levels=5) -> Tuple[Tensor, Tensor]:
48
  high_freq += (low_freq - blurred)
49
  low_freq = blurred
50
 
 
 
 
 
51
  return high_freq, low_freq
52
 
53
  def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
54
  """
55
  Applies wavelet decomposition to transfer the color/style (low-frequency components)
56
  from a style feature to the details (high-frequency components) of a content feature.
 
57
 
58
  Args:
59
  content_feat (Tensor): The tensor containing the structural details.
 
2
  #
3
  # Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
4
  #
5
+ # Version: 1.0.1
6
  #
7
  # This module provides utility functions for tensor manipulation, specifically for
8
  # image and video processing tasks. The functions here, such as wavelet reconstruction,
 
14
  import torch
15
  from torch import Tensor
16
  from torch.nn import functional as F
17
+ from typing import Tuple
18
 
19
  def wavelet_blur(image: Tensor, radius: int) -> Tensor:
20
  """
21
  Apply wavelet blur to the input tensor.
22
  """
23
+ if image.ndim != 4: # Expects (B, C, H, W)
24
+ raise ValueError(f"wavelet_blur expects a 4D tensor, but got shape {image.shape}")
25
+
26
+ b, c, h, w = image.shape
27
+
28
  # convolution kernel
29
  kernel_vals = [
30
  [0.0625, 0.125, 0.0625],
 
32
  [0.0625, 0.125, 0.0625],
33
  ]
34
  kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
35
+ kernel = kernel[None, None] # (1, 1, 3, 3)
36
+
37
+ # repeat the kernel across all input channels for grouped convolution
38
+ kernel = kernel.repeat(c, 1, 1, 1) # (C, 1, 3, 3)
39
+
40
  image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
41
+
42
+ # apply convolution with groups=c to process each channel independently
43
+ output = F.conv2d(image, kernel, groups=c, dilation=radius)
44
  return output
45
 
46
  def wavelet_decomposition(image: Tensor, levels=5) -> Tuple[Tensor, Tensor]:
 
48
  Apply wavelet decomposition to the input tensor.
49
  This function returns both the high frequency and low frequency components.
50
  """
51
+ # Ensure tensor is 4D (B, C, H, W)
52
+ is_video_frame = image.ndim == 5 # (B, C, F, H, W)
53
+ if is_video_frame:
54
+ b, c, f, h, w = image.shape
55
+ image = image.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)
56
+
57
  high_freq = torch.zeros_like(image)
58
  low_freq = image
59
  for i in range(levels):
 
62
  high_freq += (low_freq - blurred)
63
  low_freq = blurred
64
 
65
+ if is_video_frame:
66
+ high_freq = high_freq.view(b, f, c, h, w).permute(0, 2, 1, 3, 4)
67
+ low_freq = low_freq.view(b, f, c, h, w).permute(0, 2, 1, 3, 4)
68
+
69
  return high_freq, low_freq
70
 
71
  def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
72
  """
73
  Applies wavelet decomposition to transfer the color/style (low-frequency components)
74
  from a style feature to the details (high-frequency components) of a content feature.
75
+ This works for both images (4D) and videos (5D).
76
 
77
  Args:
78
  content_feat (Tensor): The tensor containing the structural details.