Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ import cv2
|
|
| 4 |
import numpy as np
|
| 5 |
from transformers import SamModel, SamProcessor, BlipProcessor, BlipForConditionalGeneration
|
| 6 |
from PIL import Image
|
|
|
|
| 7 |
|
| 8 |
# Set up device
|
| 9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -26,6 +27,26 @@ def process_mask(mask, target_size):
|
|
| 26 |
mask_image = mask_image.resize(target_size, Image.NEAREST)
|
| 27 |
return np.array(mask_image) > 0
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
def segment_image(input_image, object_name):
|
| 30 |
try:
|
| 31 |
if input_image is None:
|
|
@@ -36,9 +57,9 @@ def segment_image(input_image, object_name):
|
|
| 36 |
if not original_size or 0 in original_size:
|
| 37 |
return None, "Invalid image size. Please upload a different image."
|
| 38 |
|
| 39 |
-
# Generate image caption
|
| 40 |
blip_inputs = blip_processor(input_image, return_tensors="pt").to(device)
|
| 41 |
-
caption = blip_model.generate(**blip_inputs)
|
| 42 |
caption_text = blip_processor.decode(caption[0], skip_special_tokens=True)
|
| 43 |
|
| 44 |
# Process the image with SAM
|
|
@@ -58,15 +79,21 @@ def segment_image(input_image, object_name):
|
|
| 58 |
# Find the mask that best matches the specified object
|
| 59 |
best_mask = None
|
| 60 |
best_score = -1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
for mask in masks[0]:
|
| 62 |
mask_binary = mask.numpy() > 0.5
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
| 67 |
|
| 68 |
if best_mask is None:
|
| 69 |
-
return input_image, f"Could not find '{object_name}' in the image."
|
| 70 |
|
| 71 |
combined_mask = process_mask(best_mask, original_size)
|
| 72 |
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
from transformers import SamModel, SamProcessor, BlipProcessor, BlipForConditionalGeneration
|
| 6 |
from PIL import Image
|
| 7 |
+
from scipy.ndimage import label, center_of_mass
|
| 8 |
|
| 9 |
# Set up device
|
| 10 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 27 |
mask_image = mask_image.resize(target_size, Image.NEAREST)
|
| 28 |
return np.array(mask_image) > 0
|
| 29 |
|
| 30 |
+
def is_cat_like(mask, image_area):
|
| 31 |
+
labeled, num_features = label(mask)
|
| 32 |
+
if num_features == 0:
|
| 33 |
+
return False
|
| 34 |
+
|
| 35 |
+
largest_component = (labeled == (np.bincount(labeled.flatten())[1:].argmax() + 1))
|
| 36 |
+
area = largest_component.sum()
|
| 37 |
+
|
| 38 |
+
# Check if the area is reasonable for a cat (between 5% and 30% of image)
|
| 39 |
+
if not (0.05 * image_area < area < 0.3 * image_area):
|
| 40 |
+
return False
|
| 41 |
+
|
| 42 |
+
# Check if the shape is roughly elliptical
|
| 43 |
+
cy, cx = center_of_mass(largest_component)
|
| 44 |
+
major_axis = max(largest_component.shape)
|
| 45 |
+
minor_axis = min(largest_component.shape)
|
| 46 |
+
aspect_ratio = major_axis / minor_axis
|
| 47 |
+
|
| 48 |
+
return 1.5 < aspect_ratio < 3 # Most cats have an aspect ratio in this range
|
| 49 |
+
|
| 50 |
def segment_image(input_image, object_name):
|
| 51 |
try:
|
| 52 |
if input_image is None:
|
|
|
|
| 57 |
if not original_size or 0 in original_size:
|
| 58 |
return None, "Invalid image size. Please upload a different image."
|
| 59 |
|
| 60 |
+
# Generate detailed image caption
|
| 61 |
blip_inputs = blip_processor(input_image, return_tensors="pt").to(device)
|
| 62 |
+
caption = blip_model.generate(**blip_inputs, max_length=50)
|
| 63 |
caption_text = blip_processor.decode(caption[0], skip_special_tokens=True)
|
| 64 |
|
| 65 |
# Process the image with SAM
|
|
|
|
| 79 |
# Find the mask that best matches the specified object
|
| 80 |
best_mask = None
|
| 81 |
best_score = -1
|
| 82 |
+
image_area = original_size[0] * original_size[1]
|
| 83 |
+
|
| 84 |
+
cat_related_words = ['cat', 'kitten', 'feline', 'tabby', 'kitty']
|
| 85 |
+
caption_contains_cat = any(word in caption_text.lower() for word in cat_related_words)
|
| 86 |
+
|
| 87 |
for mask in masks[0]:
|
| 88 |
mask_binary = mask.numpy() > 0.5
|
| 89 |
+
if is_cat_like(mask_binary, image_area) and caption_contains_cat:
|
| 90 |
+
mask_area = mask_binary.sum()
|
| 91 |
+
if mask_area > best_score:
|
| 92 |
+
best_mask = mask_binary
|
| 93 |
+
best_score = mask_area
|
| 94 |
|
| 95 |
if best_mask is None:
|
| 96 |
+
return input_image, f"Could not find a suitable '{object_name}' in the image."
|
| 97 |
|
| 98 |
combined_mask = process_mask(best_mask, original_size)
|
| 99 |
|