import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from huggingface_hub import PyTorchModelHubMixin
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image

from model import ICN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def mask_processing(x):
    # Quantize each 0-255 heatmap pixel into one of three levels; the result is
    # used both as a colormap index and as the alpha mask when pasting the overlay.
    if x > 90:
        return 140
    elif x < 80:
        return 0
    else:
        return 255

def grid_to_heatmap(grid, size=1024):
    # Reshape the 49-value prediction into a 7x7 grid, upsample it to the display
    # size, quantize it, and colour it with the "Wistia" colormap.
    mask = to_pil_image(grid.view(7, 7))
    mask = mask.resize((size, size), Image.BICUBIC)
    mask = Image.eval(mask, mask_processing)
    colormap = plt.get_cmap("Wistia")
    heatmap = np.array(colormap(mask))
    heatmap = (heatmap * 255).astype(np.uint8)
    heatmap = Image.fromarray(heatmap)
    return heatmap, mask

def summary_image(img, fake, prediction):
    # Min-max normalise the prediction grid; clamp the denominator so an all-zero
    # grid (nothing to highlight) does not divide by zero.
    prediction -= prediction.min()
    prediction = prediction / prediction.max().clamp(min=1e-8)
    size = 1024
    img1 = img.resize((size, size))
    img2 = fake.resize((size, size))
    heatmap, mask = grid_to_heatmap(prediction)
    # Overlay the same heatmap on both images, using the quantized grid as alpha.
    img1.paste(heatmap, (0, 0), mask)
    img2.paste(heatmap, (0, 0), mask)
    return img1, img2

def load_model():
    # Alternative loader for a TorchScript export of the same network.
    model = torch.jit.load("traced_model.pt")
    model.eval().to(device)
    return model


model = ICN.from_pretrained("AlexBlck/image-comparator").eval().to(device)
# model = load_model()
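
# The TorchScript file expected by load_model() is not included here. As a rough
# sketch (an assumption, not the author's export script), it could be produced from
# the checkpoint above with something like:
#
#     example = torch.randn(1, 6, 224, 224, device=device)
#     traced = torch.jit.trace(model, example)
#     traced.save("traced_model.pt")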

st.title("Image Comparator Network")
st.write("## Upload a pair of images")

cols = st.columns(2)
with cols[0]:
    im1 = st.file_uploader("Image 1", type=["jpg", "png"])
with cols[1]:
    im2 = st.file_uploader("Image 2", type=["jpg", "png"])

# Wait until both images are uploaded and the button is pressed.
if not (im1 and im2):
    st.stop()

btn = st.button("Run")
if not btn:
    st.stop()

im1 = Image.open(im1).convert("RGB")
im2 = Image.open(im2).convert("RGB")

# Standard ImageNet preprocessing for each image.
tr = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Channel-stack the pair into a single (1, 6, 224, 224) input.
img = torch.vstack((tr(im1), tr(im2))).unsqueeze(0)

# Run inference without tracking gradients; the model returns a localisation grid
# and classification logits for {no manipulation, manipulation, unrelated}.
with torch.no_grad():
    heatmap, cl = model(img.to(device))
confs = torch.softmax(cl, dim=1)
pred = torch.argmax(confs, dim=1).item()

if pred == 0:
    st.success("No Manipulation Detected")
    heatmap *= 0
elif pred == 1:
    st.warning("Manipulation Detected!")
else:
    st.error("Images are not related.")
    heatmap *= 0

# Overlay the (possibly zeroed) heatmap on both inputs and show them side by side.
img1, img2 = summary_image(im1, im2, heatmap[0])
cols = st.columns(2)
with cols[0]:
    st.image(img1)
with cols[1]:
    st.image(img2)
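
For reference, here is a minimal sketch of exercising the same checkpoint outside Streamlit, mirroring the app's own preprocessing and label handling. The image file names are placeholders, not part of the app.

import torch
from PIL import Image
from torchvision import transforms
from model import ICN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ICN.from_pretrained("AlexBlck/image-comparator").eval().to(device)

tr = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Placeholder paths for the image pair to compare.
im1 = Image.open("original.jpg").convert("RGB")
im2 = Image.open("candidate.jpg").convert("RGB")

# Channel-stack the pair into one (1, 6, 224, 224) tensor, as the app does.
x = torch.vstack((tr(im1), tr(im2))).unsqueeze(0).to(device)

with torch.no_grad():
    heatmap, cl = model(x)

labels = {0: "No Manipulation Detected", 1: "Manipulation Detected!", 2: "Images are not related."}
print(labels[torch.argmax(torch.softmax(cl, dim=1), dim=1).item()])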