import numpy as np
import faiss
import torch
from torchvision.transforms import (
    Compose,
    Resize,
    ToTensor,
    Normalize,
    InterpolationMode,
    CenterCrop,
)
from PIL import Image
import gradio as gr
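# Expected files next to this script (loaded below): precomputed style
# embeddings in embs.npz, the extractor checkpoint style-extractor-v0.3.0.ckpt,
# and urls.txt with one image URL per embedding row.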
| print("starting...") | |
| (ys,) = np.load("embs.npz").values() | |
| print("loaded embs") | |
| model = torch.load( | |
| "style-extractor-v0.3.0.ckpt", | |
| map_location="cpu", weights_only=False, | |
| ) | |
| print("loaded extractor") | |
| with open("urls.txt") as f: | |
| urls = f.read().splitlines() | |
| print("loaded urls") | |
| assert len(urls) == len(ys) | |
| d = ys.shape[1] | |
| index = faiss.IndexHNSWFlat(d, 32) | |
| print("building index") | |
| index.add(ys) | |
| print("index built") | |
def MyResize(area, d):
    def f(im: Image):
        w, h = im.size
        s = (area / w / h) ** 0.5
        wd, hd = int(s * w / d), int(s * h / d)
        e = lambda a, b: 1 - min(a, b) / max(a, b)
        wd, hd = min(
            (
                (ww * d, hh * d)
                for ww, hh in [(wd + i, hd + j) for i in (0, 1) for j in (0, 1)]
                if ww * d * hh * d <= area
            ),
            key=lambda wh: e(wh[0] / wh[1], w / h),
        )
        return Compose(
            [
                Resize(
                    (int(h * wd / w), wd) if wd / w > hd / h else (hd, int(w * hd / h)),
                    InterpolationMode.BICUBIC,
                ),
                CenterCrop((hd, wd)),
            ]
        )(im)

    return f
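# MyResize(area, d) rescales the input so its pixel count stays at or below
# `area` while both output sides become multiples of `d`, choosing the candidate
# size whose aspect ratio is closest to the original, then bicubic-resizes to
# cover that size and center-crops to it exactly.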
tf = Compose(
    [
        MyResize((518 * 1.3) ** 2, 14),
        ToTensor(),
        Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
    ]
)
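# The Normalize constants are the standard ImageNet mean/std, and the area
# budget of (518 * 1.3)**2 pixels bounds how many 14x14 patches (tokens) the
# extractor has to process per image, assuming a ViT-style backbone.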
def get_emb(im: Image):
    model.eval()
    with torch.no_grad():
        # Guard against RGBA/greyscale uploads before the 3-channel Normalize.
        im = im.convert("RGB")
        # faiss expects a float32 NumPy array; the extractor is assumed to
        # return a single torch tensor here.
        return model(tf(im).unsqueeze(0)).numpy()
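# The extractor is assumed to map one preprocessed image to a single row vector
# with the same dimensionality d = ys.shape[1] as the database embeddings;
# otherwise index.search() below would reject the query.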
n_outputs = 50
row_size = 5

def f(im):
    D, I = index.search(get_emb(im), n_outputs)
    # Render each hit as markdown: the distance plus the matching database
    # image, looked up by index in urls (aligned with ys per the assert above).
    return [f"Distance: {d:.1f}\n\n![]({urls[i]})" for d, i in zip(D[0], I[0])]
| print("preparing gradio") | |
| with gr.Blocks() as demo: | |
| gr.Markdown( | |
| "# Style Similarity Search\n\nFind artworks with a similar style from a medium-sized database (10k artists * 30 img/artist)" | |
| ) | |
| img = gr.Image(type="pil", label="Query", height=500) | |
| btn = gr.Button(variant="primary", value="search") | |
| outputs = [] | |
| for i in range(-(n_outputs // (-row_size))): | |
| with gr.Row(): | |
| for _ in range(min(row_size, n_outputs - i * row_size)): | |
| outputs.append(gr.Markdown(label=f"#{len(outputs) + 1}")) | |
| btn.click(f, img, outputs) | |
| print("starting gradio") | |
| demo.launch() | |
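# To run locally (assuming this file is saved as app.py and numpy, faiss-cpu,
# torch, torchvision, pillow and gradio are installed, with the three data
# files above in the working directory):
#   python app.py
# then open the local URL that Gradio prints.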