Spaces:
Runtime error
Runtime error
| # Torchvision compatibility fix for functional_tensor module | |
| # This file helps resolve compatibility issues between different torchvision versions | |
| import sys | |
| import torch | |
| import torchvision | |
def fix_torchvision_functional_tensor():
    """Make ``torchvision.transforms.functional_tensor`` importable.

    If the real module can be imported, nothing is changed.  Otherwise a
    stand-in module is registered in ``sys.modules`` under that name.  The
    stand-in provides ``rgb_to_grayscale`` and ``resize`` and forwards any
    other public attribute lookup to ``torchvision.transforms.functional``
    (then ``torchvision.transforms.v2.functional``).

    Returns:
        bool: ``True`` if the module is importable after the call (either it
        already existed or the stand-in was installed), ``False`` if the
        stand-in could not be created.
    """
    # Function-scope imports: the file's import header does not provide these.
    import importlib
    import types

    module_name = 'torchvision.transforms.functional_tensor'

    try:
        # Check if the module exists in the expected location
        import torchvision.transforms.functional_tensor
        print("torchvision.transforms.functional_tensor is available")
        return True
    except ImportError:
        print("torchvision.transforms.functional_tensor not found, applying compatibility fix...")

    try:
        import torchvision.transforms.functional as F

        def _get_grayscale_weights(img):
            """Return luma weights shaped (3, 1, 1) on ``img``'s device.

            The weights are kept in a floating dtype: building them in an
            integer dtype would truncate all three coefficients to zero.
            The (3, 1, 1) shape broadcasts against both (3, H, W) and
            (N, 3, H, W) inputs.
            """
            dtype = img.dtype if img.is_floating_point() else torch.float32
            weights = torch.tensor([0.299, 0.587, 0.114], device=img.device, dtype=dtype)
            return weights.view(3, 1, 1)

        def _try_import_fallback(module_names, attr_name):
            """Return ``attr_name`` from the first importable module having it, else ``None``."""
            for candidate in module_names:
                try:
                    module = importlib.import_module(candidate)
                except ImportError:
                    continue
                if hasattr(module, attr_name):
                    return getattr(module, attr_name)
            return None

        def rgb_to_grayscale(img, num_output_channels=1):
            """Convert an RGB image tensor (channels at dim -3) to grayscale."""
            if hasattr(F, 'rgb_to_grayscale'):
                return F.rgb_to_grayscale(img, num_output_channels)
            # Fallback implementation: weighted channel sum computed in a
            # floating dtype, then cast back to the input dtype.
            weights = _get_grayscale_weights(img)
            grayscale = torch.sum(img.to(weights.dtype) * weights, dim=-3, keepdim=True)
            grayscale = grayscale.to(img.dtype)
            if num_output_channels == 3:
                grayscale = torch.cat([grayscale, grayscale, grayscale], dim=-3)
            return grayscale

        def resize(img, size, interpolation=2, antialias=None):
            """Resize ``img`` using the first available torchvision ``resize``."""
            # Try v2.functional first, then regular functional, then torch.nn.functional
            resize_func = _try_import_fallback([
                'torchvision.transforms.v2.functional',
                'torchvision.transforms.functional',
            ], 'resize')
            if resize_func is not None:
                try:
                    return resize_func(img, size, interpolation=interpolation, antialias=antialias)
                except TypeError:
                    # Fallback for older versions without antialias parameter
                    return resize_func(img, size, interpolation=interpolation)
            # Final fallback using torch.nn.functional
            import torch.nn.functional as torch_F
            size = (size, size) if isinstance(size, int) else size
            unbatched = img.dim() == 3
            img_input = img.unsqueeze(0) if unbatched else img
            resized = torch_F.interpolate(img_input, size=size, mode='bilinear', align_corners=False)
            # Drop the temporary batch dimension so output rank matches input rank.
            return resized.squeeze(0) if unbatched else resized

        def _module_getattr(name):
            """Module-level ``__getattr__`` (PEP 562): forward unknown public names."""
            # Never forward dunder lookups: the import system probes names
            # such as __path__/__file__, and answering with another module's
            # values would misreport what this stand-in is.
            if name.startswith('__') and name.endswith('__'):
                raise AttributeError(name)
            func = _try_import_fallback([
                'torchvision.transforms.functional',
                'torchvision.transforms.v2.functional',
            ], name)
            if func is not None:
                return func
            raise AttributeError(f"'{name}' not found in functional_tensor mock")

        # Use a real module object holding plain functions.  An instance of a
        # class would hand out bound methods, injecting the instance as the
        # first positional argument of every call.
        shim = types.ModuleType(module_name, "Mock module to replace functional_tensor")
        shim._get_grayscale_weights = _get_grayscale_weights
        shim._try_import_fallback = _try_import_fallback
        shim.rgb_to_grayscale = rgb_to_grayscale
        shim.resize = resize
        shim.__getattr__ = _module_getattr

        # Register the stand-in and bind it on the parent package so that
        # attribute-style access (torchvision.transforms.functional_tensor)
        # works as well as import statements.
        sys.modules[module_name] = shim
        parent = sys.modules.get('torchvision.transforms')
        if parent is not None:
            parent.functional_tensor = shim

        print("Applied compatibility fix: created functional_tensor mock module")
        return True
    except Exception as e:
        # Best-effort boundary: report the failure and let the caller decide.
        print(f"Failed to create functional_tensor mock: {e}")
        return False
def apply_fix():
    """Report the installed torchvision version, then run the compatibility fix.

    Returns whatever ``fix_torchvision_functional_tensor()`` returns.
    """
    print(f"Torchvision version: {torchvision.__version__}")
    outcome = fix_torchvision_functional_tensor()
    return outcome
if __name__ == "__main__":
    # Run the fix only when executed as a script, not when imported.
    apply_fix()