Spaces: Runtime error
Update similarity_inference.py
similarity_inference.py  CHANGED  (+82 −53)
@@ -1,54 +1,83 @@
-from image_helpers import convert_images_to_grayscale, crop_center_largest_contour, fetch_similar
-import datasets as ds
-import re
-import torchvision.transforms as T
-from transformers import AutoModel, AutoFeatureExtractor
-import torch
-import random
… (removed lines 8–53 of the old file not shown)
-    return match_dict
+from image_helpers import convert_images_to_grayscale, crop_center_largest_contour, fetch_similar
+import datasets as ds
+import re
+import torchvision.transforms as T
+from transformers import AutoModel, AutoFeatureExtractor
+import torch
+import random
+import os
+from PIL import Image
+import numpy as np
+
+def similarity_inference(directory):
+
+    # Get an average color value for each component image
+    color_dict = {}
+    for each_image in os.listdir(directory):
+        image_path = os.path.join(directory, each_image)
+        with Image.open(image_path) as img:
+            width, height = img.size
+            # sample 100 random pixels and keep the non-white ones
+            colors = []
+            for _ in range(100):
+                # choose a random pixel
+                random_x = random.randint(0, width - 1)
+                random_y = random.randint(0, height - 1)
+                random_pixel = img.getpixel((random_x, random_y))
+                # keep the pixel only if it is not white
+                if random_pixel != (255, 255, 255):
+                    colors.append(random_pixel)
+            if colors:
+                colors_array = np.array(colors)
+                average_color_value = tuple(np.mean(colors_array, axis=0).astype(int))
+            else:
+                # every sampled pixel was white; fall back to white
+                average_color_value = (255, 255, 255)
+            color_dict[each_image] = average_color_value
+
+    convert_images_to_grayscale(directory)
+    crop_center_largest_contour(directory)
+
+    # define processing variables needed for the embedding calculation
+    root_directory = "data/"
+    model_ckpt = "nateraw/vit-base-beans"  # TODO: find a different model
+    candidate_subset_emb = ds.load_dataset("canadianjosieharrison/2024hackathonembeddingdb")['train']
+    extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
+    model = AutoModel.from_pretrained(model_ckpt)
+    transformation_chain = T.Compose(
+        [
+            # Resize slightly larger than the model input, then take a center crop.
+            T.Resize(int((256 / 224) * extractor.size["height"])),
+            T.CenterCrop(extractor.size["height"]),
+            T.ToTensor(),
+            T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
+        ])
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    pt_directory = root_directory + "embedding_db.pt"
+    all_candidate_embeddings = torch.load(pt_directory, map_location=device, weights_only=True)
+    candidate_ids = []
+    for idx in range(len(candidate_subset_emb)):
+        # Create a unique identifier.
+        entry = str(idx) + "_" + str(random.random()).split('.')[1]
+        candidate_ids.append(entry)
+
+    # load all components
+    test_ds = ds.load_dataset("imagefolder", data_dir=directory)
+    # the same dataset with undecoded images, so the file paths stay available
+    label_filenames = test_ds.cast_column("image", ds.Image(decode=False))
+
+    # loop through each component and return the top 3 most similar
+    match_dict = {"ceiling": [],
+                  "floor": [],
+                  "wall": []}
+    for i, each_component in enumerate(test_ds['train']):
+        query_image = each_component["image"]
+        component_label = label_filenames['train'][i]['image']['path'].split('_')[-1].split("\\")[-1]
+        rgb_color = color_dict[component_label]
+        match = re.search(r"([a-zA-Z]+)(\d*)\.png", component_label)
+        component_label = match.group(1)
+        segment_id = match.group(2)
+        sim_ids = fetch_similar(query_image, transformation_chain, device, model, all_candidate_embeddings, candidate_ids)
+        for each_match in sim_ids:
+            component_texture_id = str(segment_id) + "-" + str(each_match)
+            texture_filename = candidate_subset_emb[each_match]['filenames']
+            image_url = f'https://cdn.polyhaven.com/asset_img/thumbs/{texture_filename}?width=256&height=256'
+            temp_dict = {"id": component_texture_id, "thumbnail": image_url, "name": texture_filename, "color": str(rgb_color)}
+            match_dict[component_label].append(temp_dict)
+
+    return match_dict
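For reference, a minimal sketch of how the updated function might be called. The directory name below is hypothetical; the function expects it to contain the segmented component images whose filenames match the regex above (for example, "wall3.png" parses to the component label "wall" and segment id "3"):

    import json
    from similarity_inference import similarity_inference

    # hypothetical folder of component crops: ceiling0.png, floor1.png, wall3.png, ...
    matches = similarity_inference("data/components/")

    # match_dict maps each component type to a list of candidate textures, roughly:
    # {"wall": [{"id": "3-<candidate id>",
    #            "thumbnail": "https://cdn.polyhaven.com/asset_img/thumbs/...",
    #            "name": "<texture filename>",
    #            "color": "(128, 100, 90)"}, ...], ...}
    print(json.dumps(matches, indent=2))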