import gradio as gr
import numpy as np
import tensorflow as tf
import logging
from PIL import Image
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input as resnet_preprocess
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as vgg_preprocess
import scipy.fftpack
import time
import clip
import torch
# Set up logging
logging.basicConfig(level=logging.INFO)

# Load models once at startup
resnet_model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
vgg_model = VGG16(weights='imagenet', include_top=False, pooling='avg')
clip_model, preprocess_clip = clip.load("ViT-B/32", device="cpu")
# Preprocess function
def preprocess_img(img_path, target_size=(224, 224), preprocess_func=resnet_preprocess):
    start_time = time.time()
    img = keras_image.load_img(img_path, target_size=target_size)
    img_array = keras_image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = preprocess_func(img_array)
    logging.info(f"Image preprocessed in {time.time() - start_time:.4f} seconds")
    return img_array
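# Note: with the default target_size, the returned array has shape
# (1, 224, 224, 3), i.e. a single-image batch, which is what
# model.predict below expects.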
# Feature extraction function
def extract_features(img_path, model, preprocess_func):
    img_array = preprocess_img(img_path, preprocess_func=preprocess_func)
    start_time = time.time()
    features = model.predict(img_array)
    logging.info(f"Features extracted in {time.time() - start_time:.4f} seconds")
    return features.flatten()
# Calculate cosine similarity
def cosine_similarity(vec1, vec2):
    return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
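# Sanity check (illustrative only, not executed by the app):
#   cosine_similarity(np.array([1.0, 0.0]), np.array([1.0, 0.0]))  # -> 1.0 (identical)
#   cosine_similarity(np.array([1.0, 0.0]), np.array([0.0, 1.0]))  # -> 0.0 (orthogonal)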
# pHash related functions
def phashstr(image, hash_size=8, highfreq_factor=4):
    img_size = hash_size * highfreq_factor
    image = image.convert('L').resize((img_size, img_size), Image.Resampling.LANCZOS)
    pixels = np.asarray(image)
    dct = scipy.fftpack.dct(scipy.fftpack.dct(pixels, axis=0), axis=1)
    dctlowfreq = dct[:hash_size, :hash_size]
    med = np.median(dctlowfreq)
    diff = dctlowfreq > med
    return _binary_array_to_hex(diff.flatten())
def _binary_array_to_hex(arr):
    h = 0
    s = []
    for i, v in enumerate(arr):
        if v:
            h += 2 ** (i % 8)
        if (i % 8) == 7:
            s.append(hex(h)[2:].rjust(2, '0'))
            h = 0
    return ''.join(s)
def hamming_distance(hash1, hash2):
    if len(hash1) != len(hash2):
        raise ValueError("Hashes must be of the same length")
    # Compare at the bit level: each hex digit encodes 4 bits, and
    # hamming_to_similarity is called with the bit length (len(hash) * 4),
    # so comparing hex characters directly would understate the distance
    bits1 = bin(int(hash1, 16))[2:].zfill(len(hash1) * 4)
    bits2 = bin(int(hash2, 16))[2:].zfill(len(hash2) * 4)
    return sum(b1 != b2 for b1, b2 in zip(bits1, bits2))
def hamming_to_similarity(distance, hash_length):
    return (1 - distance / hash_length) * 100
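# Illustrative pHash pipeline (hypothetical file names, assuming the images exist):
#   h1 = phashstr(Image.open("invoice_a.jpg"))
#   h2 = phashstr(Image.open("invoice_b.jpg"))
#   hamming_to_similarity(hamming_distance(h1, h2), len(h1) * 4)  # percentage in [0, 100]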
# CLIP related functions
def extract_clip_features(image_path, model, preprocess):
    image = preprocess(Image.open(image_path)).unsqueeze(0).to("cpu")
    with torch.no_grad():
        features = model.encode_image(image)
    return features.cpu().numpy().flatten()
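# Note: clip.load returns the model together with the image transform it was
# trained with, so reusing `preprocess_clip` here keeps inputs consistent with
# the ViT-B/32 checkpoint loaded above.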
# Main function
def compare_images(image1, image2, method):
    similarity = None
    start_time = time.time()
    if method == 'pHash':
        img1 = Image.open(image1)
        img2 = Image.open(image2)
        hash1 = phashstr(img1)
        hash2 = phashstr(img2)
        distance = hamming_distance(hash1, hash2)
        similarity = hamming_to_similarity(distance, len(hash1) * 4)
    elif method == 'ResNet50':
        features1 = extract_features(image1, resnet_model, resnet_preprocess)
        features2 = extract_features(image2, resnet_model, resnet_preprocess)
        similarity = cosine_similarity(features1, features2) * 100  # scale cosine to a percentage
    elif method == 'VGG16':
        features1 = extract_features(image1, vgg_model, vgg_preprocess)
        features2 = extract_features(image2, vgg_model, vgg_preprocess)
        similarity = cosine_similarity(features1, features2) * 100  # scale cosine to a percentage
    elif method == 'CLIP':
        features1 = extract_clip_features(image1, clip_model, preprocess_clip)
        features2 = extract_clip_features(image2, clip_model, preprocess_clip)
        similarity = cosine_similarity(features1, features2) * 100  # scale cosine to a percentage
    logging.info(f"AI-based supporting-document comparison using {method} completed in {time.time() - start_time:.4f} seconds")
    # Return similarity with HTML formatting for bold and colorful text
    return f"<span style='font-weight:bold; color:#4CAF50;'>Similarity Score: {similarity:.2f}%</span>"
# Gradio interface
demo = gr.Interface(
    fn=compare_images,
    inputs=[
        gr.Image(type="filepath", label="Upload First Image"),
        gr.Image(type="filepath", label="Upload Second Image"),
        gr.Radio(["pHash", "ResNet50", "VGG16", "CLIP"], label="Select Comparison Method")
    ],
    outputs=gr.HTML(label="Similarity"),  # HTML output allows bold, colored text
    title="AI-Based Customs Supporting Documents Comparison",
    description=(
        "Upload two images of supporting documents and select the comparison method.\n"
        "Fraudulent documents such as invoices are often produced by customs brokers from the same templates. "
        "This tool helps identify similar document templates used in two different consignments.\n"
        "Developed by NCTC."
    ),
    examples=[
        # Each example must supply all three inputs; the method values here are assumed defaults
        ["Snipaste_2024-05-31_16-18-31.jpg", "Snipaste_2024-05-31_16-18-52.jpg", "pHash"],
        ["example1.png", "example2.png", "pHash"]
    ]
)

demo.launch()