humbleakh commited on
Commit
518a32b
·
verified ·
1 Parent(s): 9d3b1eb
Files changed (2) hide show
  1. app.py +56 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ # Load model and processor
8
+ model_name = "nvidia/segformer-b0-finetuned-ade-512-512"
9
+ processor = SegformerImageProcessor.from_pretrained(model_name)
10
+ model = SegformerForSemanticSegmentation.from_pretrained(model_name)
11
+
12
+ # ADE20k color palette (simplified for visualization)
13
+ def create_color_map():
14
+ # Create a colorful palette for visualization
15
+ np.random.seed(42)
16
+ return np.random.randint(0, 255, size=(150, 3), dtype=np.uint8)
17
+
18
+ color_map = create_color_map()
19
+
20
+ def segment_image(image):
21
+ # Process the image
22
+ inputs = processor(images=image, return_tensors="pt")
23
+
24
+ # Get model prediction
25
+ with torch.no_grad():
26
+ outputs = model(**inputs)
27
+ logits = outputs.logits
28
+
29
+ # Get segmentation map
30
+ seg = logits.argmax(dim=1)[0].cpu().numpy()
31
+
32
+ # Convert to colored segmentation map
33
+ colored_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
34
+ for label in np.unique(seg):
35
+ colored_seg[seg == label] = color_map[label]
36
+
37
+ # Convert to PIL image for display
38
+ segmented_image = Image.fromarray(colored_seg)
39
+
40
+ return [image, segmented_image]
41
+
42
+ # Create Gradio interface
43
+ demo = gr.Interface(
44
+ fn=segment_image,
45
+ inputs=gr.Image(type="pil"),
46
+ outputs=[
47
+ gr.Image(type="pil", label="Original"),
48
+ gr.Image(type="pil", label="Segmented")
49
+ ],
50
+ title="Image Segmentation with SegFormer",
51
+ description="Upload an image to segment it into different semantic regions using SegFormer model fine-tuned on ADE20K dataset."
52
+ )
53
+
54
+ # Launch the app
55
+ if __name__ == "__main__":
56
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio>=3.50.2
2
+ transformers>=4.30.0
3
+ torch>=2.0.0
4
+ Pillow>=9.5.0
5
+ numpy>=1.24.0