Spaces:
Sleeping
Sleeping
import os
from typing import Optional, Sequence

import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.utils.data import Dataset
class VQGANDataset(Dataset):
    """Dataset of ``.npy`` volumes resampled to a cubic grid and log-compressed.

    Each item is loaded from ``os.path.join(root_dir, file_paths[idx])``. It is
    resampled with trilinear interpolation to ``internal_resolution`` along
    each of its three spatial axes, then transformed with ``log(1 + x)``.
    """

    def __init__(self, root_dir: str, file_paths: Sequence[str], internal_resolution: int):
        """Store the dataset configuration; no files are read here.

        Args:
            root_dir: Directory that the entries of ``file_paths`` are joined onto.
            file_paths: Paths of the ``.npy`` files (relative to ``root_dir``),
                one per sample. Must support ``len()`` and integer indexing.
            internal_resolution: Edge length of the cubic output grid.
        """
        super().__init__()
        self.root_dir = root_dir
        self.file_paths = file_paths
        self.internal_resolution = internal_resolution

    def __len__(self) -> int:
        """Return the number of samples, i.e. ``len(file_paths)``."""
        return len(self.file_paths)

    def __getitem__(self, idx: int) -> Optional[torch.Tensor]:
        """Load, resample and log-compress sample ``idx``.

        Returns:
            A float32 tensor of shape ``(1, R, R, R)`` where
            ``R = internal_resolution``. The leading dimension is a channel axis.
            Returns ``None`` (after printing the error) if anything fails:
            a missing or unreadable file, or an array that is not 3-D.
            Presumably callers use a ``collate_fn`` that drops ``None`` —
            TODO confirm; PyTorch's default collate cannot batch ``None``.
        """
        filename = os.path.join(self.root_dir, self.file_paths[idx])
        try:
            volume = torch.from_numpy(np.load(filename)).float()
            # Add batch and channel axes: trilinear mode requires (N, C, D, H, W) input,
            # so the stored array must be 3-D.
            volume = volume.unsqueeze(0).unsqueeze(0)
            target_size = (self.internal_resolution,) * 3
            resampled = F.interpolate(volume, size=target_size, mode="trilinear", align_corners=False)
            # log(1 + x): maps 0 to 0 rather than -inf (assumes non-negative data — TODO confirm).
            # log1p is more accurate than log(x + 1) for small x. Squeeze drops only the batch axis.
            return torch.log1p(resampled).squeeze(0)
        except Exception as e:
            # Deliberate best-effort: a bad sample is reported and yields None instead of
            # aborting the epoch.
            print(f"Error loading file '{filename}': {e}")
            return None