import os

import clip
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

RESOURCES_ROOT = "scripts/util/detection/"


def predict_proba(X, weights, biases):
    # Sigmoid of a linear head: sigma(X @ weights.T + biases). The np.where
    # selects, per element, the formulation that stays finite for that sign
    # of the logit.
    logits = X @ weights.T + biases
    proba = np.where(
        logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits))
    )
    return proba.T
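

# Illustrative sketch, not part of the original module: predict_proba is simply
# a sigmoid over a linear head. This hypothetical reference helper shows the
# naive form it is equivalent to (up to the overflow the branching above avoids).
def _predict_proba_reference(X, weights, biases):
    return (1 / (1 + np.exp(-(X @ weights.T + biases)))).T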


def load_model_weights(path: str):
    # Each checkpoint is an .npz archive storing a linear head as
    # "weights" and "biases" arrays.
    model_weights = np.load(path)
    return model_weights["weights"], model_weights["biases"]


def clip_process_images(images: torch.Tensor) -> torch.Tensor:
    min_size = min(images.shape[-2:])
    return T.Compose(
        [
            T.CenterCrop(min_size),  # TODO: this might affect the watermark, check this
            T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
            # CLIP's standard image normalization statistics.
            T.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )(images)
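

def _example_clip_preprocess() -> torch.Tensor:
    # Illustrative only (hypothetical helper, not used by the filter): the
    # assumed input convention is a float batch in [0, 1] shaped [B, 3, H, W],
    # which is what load_img below produces.
    batch = torch.rand(2, 3, 512, 512)   # two random "RGB images"
    return clip_process_images(batch)    # -> [2, 3, 224, 224], CLIP-normalized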


class DeepFloydDataFiltering(object):
    def __init__(
        self, verbose: bool = False, device: torch.device = torch.device("cpu")
    ):
        super().__init__()
        self.verbose = verbose
        self._device = None
        # CLIP ViT-L/14 provides the image features; the two linear heads
        # loaded below are applied to those features on the CPU via NumPy.
        self.clip_model, _ = clip.load("ViT-L/14", device=device)
        self.clip_model.eval()

        self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
            os.path.join(RESOURCES_ROOT, "w_head_v1.npz")
        )
        self.cpu_p_weights, self.cpu_p_biases = load_model_weights(
            os.path.join(RESOURCES_ROOT, "p_head_v1.npz")
        )
        self.w_threshold, self.p_threshold = 0.5, 0.5

    def __call__(self, images: torch.Tensor) -> torch.Tensor:
        imgs = clip_process_images(images)
        if self._device is None:
            self._device = next(p for p in self.clip_model.parameters()).device
        image_features = self.clip_model.encode_image(imgs.to(self._device))
        image_features = image_features.detach().cpu().numpy().astype(np.float16)
        p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
        w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
        print(f"p_pred = {p_pred}, w_pred = {w_pred}") if self.verbose else None
        # Images flagged by either head are not dropped; they are replaced
        # in place with a heavy Gaussian blur, keeping the batch shape intact.
        query = p_pred > self.p_threshold
        if query.sum() > 0:
            print(f"Hit for p_threshold: {p_pred}") if self.verbose else None
            images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
        query = w_pred > self.w_threshold
        if query.sum() > 0:
            print(f"Hit for w_threshold: {w_pred}") if self.verbose else None
            images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
        return images


def load_img(path: str) -> torch.Tensor:
    # Returns a [1, 3, H, W] float tensor with values in [0, 1].
    image = Image.open(path)
    if not image.mode == "RGB":
        image = image.convert("RGB")
    image_transforms = T.Compose(
        [
            T.ToTensor(),
        ]
    )
    return image_transforms(image)[None, ...]
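

def _example_batch_filtering(paths):
    # Illustrative sketch (hypothetical helper, not part of the original file):
    # filter several images in one call. Assumes the *_head_v1.npz checkpoints
    # exist under RESOURCES_ROOT and that all images share the same resolution
    # so they can be concatenated into one batch.
    data_filter = DeepFloydDataFiltering(verbose=True)
    batch = torch.cat([load_img(p) for p in paths], dim=0)  # [B, 3, H, W] in [0, 1]
    return data_filter(batch)  # flagged images come back heavily blurred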


def test(root):
    from einops import rearrange

    # Run the filter over every file in `root` and save the (possibly blurred)
    # result alongside it as "<name>-filtered.jpg".
    filter = DeepFloydDataFiltering(verbose=True)
    for p in os.listdir(root):
        print(f"running on {p}...")
        img = load_img(os.path.join(root, p))
        filtered_img = filter(img)
        filtered_img = rearrange(
            255.0 * (filtered_img.numpy())[0], "c h w -> h w c"
        ).astype(np.uint8)
        Image.fromarray(filtered_img).save(
            os.path.join(root, f"{os.path.splitext(p)[0]}-filtered.jpg")
        )
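

# Example invocation via Fire (the script path is hypothetical and depends on
# where this module lives in the repository):
#
#   python <path_to_this_file>.py --root path/to/images/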
| if __name__ == "__main__": | |
| import fire | |
| fire.Fire(test) | |
| print("done.") | |