capjamesg commited on
Commit
612b790
·
verified ·
1 Parent(s): 959a897

add rf-detr seg

Browse files
Files changed (1) hide show
  1. app.py +30 -21
app.py CHANGED
@@ -6,7 +6,7 @@ import gradio as gr
6
  import numpy as np
7
  import supervision as sv
8
  from PIL import Image
9
- from rfdetr import RFDETRNano, RFDETRSmall, RFDETRMedium, RFDETRBase, RFDETRLarge
10
  from rfdetr.detr import RFDETR
11
  from rfdetr.util.coco_classes import COCO_CLASSES
12
 
@@ -25,16 +25,16 @@ by [Roboflow](https://roboflow.com/) and released under the Apache 2.0 license.
25
  """
26
 
27
  IMAGE_PROCESSING_EXAMPLES = [
28
- ['https://media.roboflow.com/supervision/image-examples/people-walking.png', 0.3, 1024, "medium"],
29
- ['https://media.roboflow.com/supervision/image-examples/vehicles.png', 0.3, 1024, "medium"],
30
- ['https://media.roboflow.com/supervision/image-examples/motorbike.png', 0.3, 1024, "medium"],
31
- ['https://media.roboflow.com/notebooks/examples/dog-2.jpeg', 0.5, 512, "nano"],
32
- ['https://media.roboflow.com/notebooks/examples/dog-3.jpeg', 0.5, 512, "nano"],
33
- ['https://media.roboflow.com/supervision/image-examples/basketball-1.png', 0.5, 512, "nano"],
34
  ]
35
  VIDEO_PROCESSING_EXAMPLES = [
36
- ["videos/people-walking.mp4", 0.3, 1024, "medium"],
37
- ["videos/vehicles.mp4", 0.3, 1024, "medium"],
38
  ]
39
 
40
  COLOR = sv.ColorPalette.from_hex([
@@ -53,6 +53,7 @@ def detect_and_annotate(
53
  model: RFDETR,
54
  image: ImageType,
55
  confidence: float,
 
56
  ) -> ImageType:
57
  detections = model.predict(image, threshold=confidence)
58
 
@@ -76,27 +77,34 @@ def detect_and_annotate(
76
  annotated_image = image.copy()
77
  annotated_image = bbox_annotator.annotate(annotated_image, detections)
78
  annotated_image = label_annotator.annotate(annotated_image, detections, labels)
 
 
 
79
  return annotated_image
80
 
81
 
82
  def load_model(resolution: int, checkpoint: str) -> RFDETR:
83
- if checkpoint == "nano":
84
  return RFDETRNano(resolution=resolution)
85
- if checkpoint == "small":
86
  return RFDETRSmall(resolution=resolution)
87
- if checkpoint == "medium":
88
  return RFDETRMedium(resolution=resolution)
89
- if checkpoint == "base":
90
  return RFDETRBase(resolution=resolution)
91
- elif checkpoint == "large":
92
  return RFDETRLarge(resolution=resolution)
 
 
93
  raise TypeError("Checkpoint must be a base or large.")
94
 
95
 
96
  def adjust_resolution(checkpoint: str, resolution: int) -> int:
97
- if checkpoint in {"nano", "small", "medium"}:
 
 
98
  divisor = 32
99
- elif checkpoint in {"base", "large"}:
100
  divisor = 56
101
  else:
102
  raise ValueError(f"Unknown checkpoint: {checkpoint}")
@@ -121,7 +129,7 @@ def image_processing_inference(
121
  ):
122
  resolution = adjust_resolution(checkpoint=checkpoint, resolution=resolution)
123
  model = load_model(resolution=resolution, checkpoint=checkpoint)
124
- return detect_and_annotate(model=model, image=input_image, confidence=confidence)
125
 
126
 
127
  def video_processing_inference(
@@ -149,6 +157,7 @@ def video_processing_inference(
149
  model=model,
150
  image=frame,
151
  confidence=confidence,
 
152
  )
153
  annotated_frame = sv.scale_image(annotated_frame, VIDEO_SCALE_FACTOR)
154
  sink.write_frame(annotated_frame)
@@ -189,8 +198,8 @@ with gr.Blocks() as demo:
189
  )
190
  image_processing_checkpoint_dropdown = gr.Dropdown(
191
  label="Checkpoint",
192
- choices=["nano", "small", "medium"],
193
- value="medium"
194
  )
195
  with gr.Column():
196
  image_processing_submit_button = gr.Button("Submit", value="primary")
@@ -245,8 +254,8 @@ with gr.Blocks() as demo:
245
  )
246
  video_processing_checkpoint_dropdown = gr.Dropdown(
247
  label="Checkpoint",
248
- choices=["nano", "small", "medium"],
249
- value="medium"
250
  )
251
  with gr.Column():
252
  video_processing_submit_button = gr.Button("Submit", value="primary")
 
6
  import numpy as np
7
  import supervision as sv
8
  from PIL import Image
9
+ from rfdetr import RFDETRNano, RFDETRSmall, RFDETRMedium, RFDETRBase, RFDETRLarge, RFDETRSegPreview
10
  from rfdetr.detr import RFDETR
11
  from rfdetr.util.coco_classes import COCO_CLASSES
12
 
 
25
  """
26
 
27
  IMAGE_PROCESSING_EXAMPLES = [
28
+ ['https://media.roboflow.com/supervision/image-examples/people-walking.png', 0.3, 1024, "medium (object detection)"],
29
+ ['https://media.roboflow.com/supervision/image-examples/vehicles.png', 0.3, 1024, "medium (object detection)"],
30
+ ['https://media.roboflow.com/supervision/image-examples/motorbike.png', 0.3, 1024, "medium (object detection)"],
31
+ ['https://media.roboflow.com/notebooks/examples/dog-2.jpeg', 0.5, 512, "nano (object detection)"],
32
+ ['https://media.roboflow.com/notebooks/examples/dog-3.jpeg', 0.5, 512, "nano (object detection)"],
33
+ ['https://media.roboflow.com/supervision/image-examples/basketball-1.png', 0.5, 512, "nano (object detection)"],
34
  ]
35
  VIDEO_PROCESSING_EXAMPLES = [
36
+ ["videos/people-walking.mp4", 0.3, 1024, "medium (object detection)"],
37
+ ["videos/vehicles.mp4", 0.3, 1024, "medium (object detection)"],
38
  ]
39
 
40
  COLOR = sv.ColorPalette.from_hex([
 
53
  model: RFDETR,
54
  image: ImageType,
55
  confidence: float,
56
+ checkpoint: str = "medium (object detection)"
57
  ) -> ImageType:
58
  detections = model.predict(image, threshold=confidence)
59
 
 
77
  annotated_image = image.copy()
78
  annotated_image = bbox_annotator.annotate(annotated_image, detections)
79
  annotated_image = label_annotator.annotate(annotated_image, detections, labels)
80
+ if checkpoint == "segmentation preview":
81
+ mask_annotator = sv.MaskAnnotator()
82
+ annotated_image = mask_annotator.annotate(annotated_image, detections)
83
  return annotated_image
84
 
85
 
86
  def load_model(resolution: int, checkpoint: str) -> RFDETR:
87
+ if checkpoint == "nano (object detection)":
88
  return RFDETRNano(resolution=resolution)
89
+ if checkpoint == "small (object detection)":
90
  return RFDETRSmall(resolution=resolution)
91
+ if checkpoint == "medium (object detection)":
92
  return RFDETRMedium(resolution=resolution)
93
+ if checkpoint == "base (object detection)":
94
  return RFDETRBase(resolution=resolution)
95
+ elif checkpoint == "large (object detection)":
96
  return RFDETRLarge(resolution=resolution)
97
+ elif checkpoint == "segmentation preview":
98
+ return RFDETRSegPreview(resolution=resolution)
99
  raise TypeError("Checkpoint must be a base or large.")
100
 
101
 
102
  def adjust_resolution(checkpoint: str, resolution: int) -> int:
103
+ if checkpoint == "segmentation preview":
104
+ divisor = 24
105
+ elif checkpoint in {"nano (object detection)", "small (object detection)", "medium (object detection)"}:
106
  divisor = 32
107
+ elif checkpoint in {"base (object detection)", "large (object detection)"}:
108
  divisor = 56
109
  else:
110
  raise ValueError(f"Unknown checkpoint: {checkpoint}")
 
129
  ):
130
  resolution = adjust_resolution(checkpoint=checkpoint, resolution=resolution)
131
  model = load_model(resolution=resolution, checkpoint=checkpoint)
132
+ return detect_and_annotate(model=model, image=input_image, confidence=confidence, checkpoint=checkpoint)
133
 
134
 
135
  def video_processing_inference(
 
157
  model=model,
158
  image=frame,
159
  confidence=confidence,
160
+ checkpoint=checkpoint
161
  )
162
  annotated_frame = sv.scale_image(annotated_frame, VIDEO_SCALE_FACTOR)
163
  sink.write_frame(annotated_frame)
 
198
  )
199
  image_processing_checkpoint_dropdown = gr.Dropdown(
200
  label="Checkpoint",
201
+ choices=["nano (object detection)", "small (object detection)", "medium (object detection)", "segmentation preview"],
202
+ value="segmentation preview"
203
  )
204
  with gr.Column():
205
  image_processing_submit_button = gr.Button("Submit", value="primary")
 
254
  )
255
  video_processing_checkpoint_dropdown = gr.Dropdown(
256
  label="Checkpoint",
257
+ choices=["nano (object detection)", "small (object detection)", "medium (object detection)", "segmentation preview"],
258
+ value="segmentation preview"
259
  )
260
  with gr.Column():
261
  video_processing_submit_button = gr.Button("Submit", value="primary")