Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | 
         @@ -6,34 +6,51 @@ from torchvision import transforms 
     | 
|
| 6 | 
         
             
            from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
         
     | 
| 7 | 
         
             
            import matplotlib.pyplot as plt
         
     | 
| 8 | 
         
             
            import gradio as gr
         
     | 
| 
         | 
|
| 9 | 
         
             
            # import segmentation_models_pytorch as smp
         
     | 
| 10 | 
         | 
| 
         | 
|
| 
         | 
|
| 11 | 
         | 
| 12 | 
         | 
| 13 | 
         
             
            # image= cv2.imread('image_4.png', cv2.IMREAD_COLOR)
         
     | 
| 14 | 
         
            -
            def get_masks( 
     | 
| 15 | 
         
            -
                 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 16 | 
         
             
                    sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
         
     | 
| 17 | 
         
            -
             
     | 
| 18 | 
         
            -
                if model_type,all() == 'vit_b':
         
     | 
| 19 | 
         
             
                    sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
         
     | 
| 20 | 
         | 
| 21 | 
         
            -
                if model_type 
     | 
| 22 | 
         
             
                    sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
         
     | 
| 
         | 
|
| 
         | 
|
| 23 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 24 | 
         
             
                mask_generator = SamAutomaticMaskGenerator(sam)
         
     | 
| 25 | 
         
             
                masks = mask_generator.generate(image)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 26 | 
         
             
                for i, mask_data in enumerate(masks):
         
     | 
| 27 | 
         
             
                    mask = mask_data['segmentation']
         
     | 
| 28 | 
         
             
                    color = colors[i]
         
     | 
| 29 | 
         
             
                    composite_image[mask] = (color[:3] * 255).astype(np.uint8)  # Apply color to mask
         
     | 
| 
         | 
|
| 30 | 
         | 
| 31 | 
         
             
                # Combine original image with the composite mask image
         
     | 
| 32 | 
         
            -
                overlayed_image = (composite_image * 0.5 +  
     | 
| 
         | 
|
| 33 | 
         
             
                return overlayed_image
         
     | 
| 34 | 
         | 
| 35 | 
         | 
| 36 | 
         | 
| 
         | 
|
| 37 | 
         
             
            iface = gr.Interface(
         
     | 
| 38 | 
         
             
                fn=get_masks,
         
     | 
| 39 | 
         
             
                inputs=["image", gr.components.Dropdown(choices=['vit_h', 'vit_b', 'vit_l'], label="Model Type")],
         
     | 
| 
         @@ -42,5 +59,4 @@ iface = gr.Interface( 
     | 
|
| 42 | 
         
             
                description="Upload an image, select a model type, and receive the segmented and classified parts."
         
     | 
| 43 | 
         
             
            )
         
     | 
| 44 | 
         | 
| 45 | 
         
            -
             
     | 
| 46 | 
         
             
            iface.launch()
         
     | 
| 
         | 
|
| 6 | 
         
             
            from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
         
     | 
| 7 | 
         
             
            import matplotlib.pyplot as plt
         
     | 
| 8 | 
         
             
            import gradio as gr
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
             
            # import segmentation_models_pytorch as smp
         
     | 
| 11 | 
         | 
| 12 | 
         
            +
            ##set the device to cuda for sam model
         
     | 
| 13 | 
         
            +
            # device = torch.device('cuda')
         
     | 
| 14 | 
         | 
| 15 | 
         | 
| 16 | 
         
             
            # image= cv2.imread('image_4.png', cv2.IMREAD_COLOR)
         
     | 
| 17 | 
         
            +
            def get_masks( image, model_type):
         
     | 
| 18 | 
         
            +
                print(image)
         
     | 
| 19 | 
         
            +
                # image_pil = Image.fromarray(image.astype('uint8'), 'RGB')
         
     | 
| 20 | 
         
            +
                # print(image_pil)
         
     | 
| 21 | 
         
            +
                if model_type == 'vit_h':
         
     | 
| 22 | 
         
             
                    sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
         
     | 
| 23 | 
         
            +
                if model_type == 'vit_b':
         
     | 
| 
         | 
|
| 24 | 
         
             
                    sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
         
     | 
| 25 | 
         | 
| 26 | 
         
            +
                if model_type == 'vit_l':
         
     | 
| 27 | 
         
             
                    sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
         
     | 
| 28 | 
         
            +
                else:
         
     | 
| 29 | 
         
            +
                    sam=  sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
         
     | 
| 30 | 
         | 
| 31 | 
         
            +
                # print(image.shape)
         
     | 
| 32 | 
         
            +
                #set the device to cuda for sam model
         
     | 
| 33 | 
         
            +
                # sam.to(device= device)
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
             
                mask_generator = SamAutomaticMaskGenerator(sam)
         
     | 
| 36 | 
         
             
                masks = mask_generator.generate(image)
         
     | 
| 37 | 
         
            +
                composite_image = np.zeros_like(image)
         
     | 
| 38 | 
         
            +
                colors = plt.cm.jet(np.linspace(0, 1, len(masks)))  # Generate distinct colors
         
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
             
                for i, mask_data in enumerate(masks):
         
     | 
| 41 | 
         
             
                    mask = mask_data['segmentation']
         
     | 
| 42 | 
         
             
                    color = colors[i]
         
     | 
| 43 | 
         
             
                    composite_image[mask] = (color[:3] * 255).astype(np.uint8)  # Apply color to mask
         
     | 
| 44 | 
         
            +
                print(composite_image.shape, image.shape)
         
     | 
| 45 | 
         | 
| 46 | 
         
             
                # Combine original image with the composite mask image
         
     | 
| 47 | 
         
            +
                overlayed_image = (composite_image * 0.5 + torch.from_numpy(image).resize(738, 1200, 3).cpu().numpy() * 0.5).astype(np.uint8)
         
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
             
                return overlayed_image
         
     | 
| 50 | 
         | 
| 51 | 
         | 
| 52 | 
         | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
             
            iface = gr.Interface(
         
     | 
| 55 | 
         
             
                fn=get_masks,
         
     | 
| 56 | 
         
             
                inputs=["image", gr.components.Dropdown(choices=['vit_h', 'vit_b', 'vit_l'], label="Model Type")],
         
     | 
| 
         | 
|
| 59 | 
         
             
                description="Upload an image, select a model type, and receive the segmented and classified parts."
         
     | 
| 60 | 
         
             
            )
         
     | 
| 61 | 
         | 
| 
         | 
|
| 62 | 
         
             
            iface.launch()
         
     |