import onnx
import onnxruntime as ort
import numpy as np
import cv2
import yaml
import copy


class DetectionModel:
    def __init__(self, model_path="weights/best.onnx"):
        self.current_input = None
        self.latest_output = None
        self.model = None
        self.model_ckpt = model_path
        if self.__check_model():
            self.model = ort.InferenceSession(self.model_ckpt)
        else:
            raise Exception("Model couldn't be validated using ONNX, please check the checkpoint")
        self._load_labels()
    def __check_model(self):
        """Return True if the checkpoint loads and passes the ONNX checker, False otherwise."""
        try:
            model = onnx.load(self.model_ckpt)
            onnx.checker.check_model(model)
            return True
        except Exception:
            return False
    def _preprocess_input(self, image: np.ndarray) -> np.ndarray:
        """Preprocess the input image.

        Resizes the image to 640x640, transposes it to CxHxW and normalizes it to [0, 1].
        The result is converted to `np.float32` and returned with an extra batch dimension.

        Args:
            image (np.ndarray): The input image

        Returns:
            processed_image (np.ndarray): The preprocessed image as a 1x3x640x640 `np.float32` array
        """
        processed_image = copy.deepcopy(image)
        processed_image = cv2.resize(processed_image, (640, 640))
        processed_image = processed_image.transpose(2, 0, 1)
        processed_image = (processed_image / 255.0).astype(np.float32)
        processed_image = np.expand_dims(processed_image, axis=0)
        return processed_image
    def _postprocess_output(self, predictions) -> list[list]:
        """Postprocess the output of the model.

        Args:
            predictions (np.ndarray): The output of the model as a `np.ndarray`

        Returns:
            detections (list[list]): The detections as a list of N entries, one per detection.
                Each entry is [x1, y1, x2, y2, confidence, label], with box coordinates scaled
                back to the original input image size.
        """
        # Scale factors from the 640x640 model input back to the original image
        w_ratio = self.current_input.shape[1] / 640
        h_ratio = self.current_input.shape[0] / 640
        detections = []
        for pred in predictions:
            detections.append([
                int(pred[0] * w_ratio),
                int(pred[1] * h_ratio),
                int(pred[2] * w_ratio),
                int(pred[3] * h_ratio),
                pred[4],
                self.ix2l[int(pred[5])],
            ])
        return detections
    def _load_labels(self):
        """Load class names from data.yaml and build label <-> index mappings."""
        with open("data.yaml", "r") as f:
            data = yaml.safe_load(f)
        self.labels = data['names']
        self.l2ix = {l: i for i, l in enumerate(self.labels)}
        self.ix2l = {i: l for i, l in enumerate(self.labels)}
    def __call__(self, image: np.ndarray):
        self.current_input = image
        processed_image = self._preprocess_input(image)
        self.latest_output = list(self.model.run(None, {"images": processed_image})[0][0])
        detections = self._postprocess_output(self.latest_output)
        return detections
    def visualize(self, input_image: np.ndarray, detections: list[list]) -> np.ndarray:
        """Visualize the detections on an image.

        Args:
            input_image (np.ndarray): The image to draw on
            detections (list[list]): The detections as a list of [x1, y1, x2, y2, confidence, label]

        Returns:
            image (np.ndarray): The image with the detections drawn on it, resized to the
                shape of the current input image
        """
        image = copy.deepcopy(input_image)
        for det in detections:
            x1, y1, x2, y2, conf, label = det
            cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(image, f"{label}: {conf:.3f}", (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA)
        image = cv2.resize(image, self.current_input.shape[:2][::-1])
        return image
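

# --- Usage sketch (not part of the original file) ---
# A minimal example of how the class above is meant to be wired together, assuming a
# YOLO-style ONNX checkpoint at "weights/best.onnx" and a "data.yaml" containing a
# `names` list of class labels. The image and output paths below are placeholders.
if __name__ == "__main__":
    model = DetectionModel(model_path="weights/best.onnx")  # placeholder checkpoint path
    image = cv2.imread("sample.jpg")                        # placeholder input image (BGR, as read by OpenCV)
    detections = model(image)                               # [[x1, y1, x2, y2, conf, label], ...]
    annotated = model.visualize(image, detections)
    cv2.imwrite("output.jpg", annotated)                    # placeholder output path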