Spaces:
Sleeping
Sleeping
| import torch | |
| from accelerate.test_utils.testing import get_backend | |
| from PIL import Image | |
| import os | |
| import sys | |
| from config import LOGS_DIR, DEPTH_FM_CHECKPOINT, DEPTH_FM_DIR | |
| sys.path.append(DEPTH_FM_DIR + '/depthfm') | |
| from dfm import DepthFM | |
| from unet import UNetModel | |
| import einops | |
| import numpy as np | |
| from torchvision import transforms | |
class DepthEstimator:
    """Run DepthFM monocular depth estimation over a directory of images.

    The model is created lazily on the first call to ``estimate_depth`` from
    ``DEPTH_FM_CHECKPOINT`` and is moved back to the CPU after every run so
    that accelerator memory is not held between calls.
    """

    # Lower-case file extensions that are treated as input images.
    _IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, image_dir = LOGS_DIR):
        """Store the default image directory and pick the compute device.

        Args:
            image_dir: Directory callers may pass to ``estimate_depth``;
                defaults to ``LOGS_DIR``.
        """
        # get_backend() yields a 3-tuple; only the device entry is needed here.
        self.device, _, _ = get_backend()
        self.image_dir = image_dir
        self.model = None  # created on first use by _load_model()

    def _load_model(self):
        """Instantiate the model if needed and place it on ``self.device``."""
        if self.model is None:
            self.model = DepthFM(DEPTH_FM_CHECKPOINT)
        self.model = self.model.to(self.device).eval()

    def _unload_model(self):
        """Move the model back to the CPU and release cached CUDA memory."""
        if self.model is not None:
            self.model = self.model.to("cpu")
        torch.cuda.empty_cache()

    def estimate_depth(self, image_path : str, num_steps : int = 2,
                       ensemble_size : int = 4) -> list:
        """Predict a depth map for every image found in ``image_path``.

        Files are selected by extension (case-insensitive ``.jpg``, ``.jpeg``,
        ``.png``) and processed in sorted name order so that the position of
        each result in the returned list is reproducible.

        Args:
            image_path: Directory containing the input images.
            num_steps: Forwarded to ``DepthFM.predict_depth``.
            ensemble_size: Forwarded to ``DepthFM.predict_depth``.

        Returns:
            A list with one ``{"depth": PIL.Image}`` dict per input image.
        """
        print("Estimating depth...")
        predictions_list = []
        to_pil = transforms.ToPILImage()  # loop-invariant, build once
        image_names = sorted(
            name for name in os.listdir(image_path)
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
        )

        self._load_model()
        try:
            for name in image_names:
                # The context manager closes the file handle; convert("RGB")
                # guarantees an (h, w, 3) array for greyscale/RGBA/palette input.
                with Image.open(os.path.join(image_path, name)) as image:
                    x = np.array(image.convert("RGB"))
                x = einops.rearrange(x, 'h w c -> c h w')
                x = x / 127.5 - 1  # scale uint8 [0, 255] to [-1, 1]
                x = torch.tensor(x, dtype=torch.float32)[None]  # add batch dim

                with torch.no_grad():
                    depth = self.model.predict_depth(
                        x.to(self.device),
                        num_steps=num_steps,
                        ensemble_size=ensemble_size,
                    )  # returns a tensor

                # Tensor.cpu() is not in-place: the result must be re-bound.
                depth = depth.cpu()
                # NOTE(review): ToPILImage scales float tensors by 255, so this
                # assumes predict_depth returns values in [0, 1] - confirm.
                predictions_list.append({"depth": to_pil(depth.squeeze())})

                del x, depth
                torch.cuda.empty_cache()
        finally:
            # Always give the accelerator back, even if inference failed.
            self._unload_model()

        print("Depth estimation complete.")
        return predictions_list

    def visualize(self, predictions_list : list, output_dir : str = ".") -> None:
        """Save each predicted depth image as ``depth_<index>.png``.

        Args:
            predictions_list: Output of ``estimate_depth``.
            output_dir: Directory to write into; created if missing. Defaults
                to the current working directory.
        """
        os.makedirs(output_dir, exist_ok=True)
        for i, prediction in enumerate(predictions_list):
            prediction["depth"].save(os.path.join(output_dir, f"depth_{i}.png"))
# Example usage:
# Estimator = DepthEstimator()
# predictions = Estimator.estimate_depth(Estimator.image_dir)
# Estimator.visualize(predictions)