Spaces:
Sleeping
Sleeping
| import random | |
| from typing import List, Optional, Union | |
| import numpy as np | |
| import torch | |
| from PIL import Image, ImageDraw, ImageFont | |
| from torchvision.transforms.functional import to_pil_image | |
| from yolo.config.config import ModelConfig | |
| from yolo.model.yolo import YOLO | |
| from yolo.utils.logger import logger | |
| def draw_bboxes( | |
| img: Union[Image.Image, torch.Tensor], | |
| bboxes: List[List[Union[int, float]]], | |
| *, | |
| idx2label: Optional[list] = None, | |
| ): | |
| """ | |
| Draw bounding boxes on an image. | |
| Args: | |
| - img (PIL Image or torch.Tensor): Image on which to draw the bounding boxes. | |
| - bboxes (List of Lists/Tensors): Bounding boxes with [class_id, x_min, y_min, x_max, y_max], | |
| where coordinates are normalized [0, 1]. | |
| """ | |
| # Convert tensor image to PIL Image if necessary | |
| if isinstance(img, torch.Tensor): | |
| if img.dim() > 3: | |
| logger.warning("π >3 dimension tensor detected, using the 0-idx image.") | |
| img = img[0] | |
| img = to_pil_image(img) | |
| if isinstance(bboxes, list) or bboxes.ndim == 3: | |
| bboxes = bboxes[0] | |
| img = img.copy() | |
| label_size = img.size[1] / 30 | |
| draw = ImageDraw.Draw(img, "RGBA") | |
| try: | |
| font = ImageFont.truetype("arial.ttf", int(label_size)) | |
| except IOError: | |
| font = ImageFont.load_default(int(label_size)) | |
| for bbox in bboxes: | |
| class_id, x_min, y_min, x_max, y_max, *conf = [float(val) for val in bbox] | |
| x_min, x_max = min(x_min, x_max), max(x_min, x_max) | |
| y_min, y_max = min(y_min, y_max), max(y_min, y_max) | |
| bbox = [(x_min, y_min), (x_max, y_max)] | |
| random.seed(int(class_id)) | |
| color_map = (random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)) | |
| draw.rounded_rectangle(bbox, outline=(*color_map, 200), radius=5, width=2) | |
| draw.rounded_rectangle(bbox, fill=(*color_map, 100), radius=5) | |
| class_text = str(idx2label[int(class_id)] if idx2label else int(class_id)) | |
| label_text = f"{class_text}" + (f" {conf[0]: .0%}" if conf else "") | |
| text_bbox = font.getbbox(label_text) | |
| text_width = text_bbox[2] - text_bbox[0] | |
| text_height = (text_bbox[3] - text_bbox[1]) * 1.5 | |
| text_background = [(x_min, y_min), (x_min + text_width, y_min + text_height)] | |
| draw.rounded_rectangle(text_background, fill=(*color_map, 175), radius=2) | |
| draw.text((x_min, y_min), label_text, fill="white", font=font) | |
| return img | |
| def draw_model(*, model_cfg: ModelConfig = None, model: YOLO = None, v7_base=False): | |
| from graphviz import Digraph | |
| if model_cfg: | |
| from yolo.model.yolo import create_model | |
| model = create_model(model_cfg) | |
| elif model is None: | |
| raise ValueError("Drawing Object is None") | |
| model_size = len(model.model) + 1 | |
| model_mat = np.zeros((model_size, model_size), dtype=bool) | |
| layer_name = ["INPUT"] | |
| for idx, layer in enumerate(model.model, start=1): | |
| layer_name.append(str(type(layer)).split(".")[-1][:-2]) | |
| if layer.tags is not None: | |
| layer_name[-1] = f"{layer.tags}-{layer_name[-1]}" | |
| if isinstance(layer.source, int): | |
| source = layer.source + (layer.source < 0) * idx | |
| model_mat[source, idx] = True | |
| else: | |
| for source in layer.source: | |
| source = source + (source < 0) * idx | |
| model_mat[source, idx] = True | |
| pattern_mat = [] | |
| if v7_base: | |
| pattern_list = [("ELAN", 8, 3), ("ELAN", 8, 55), ("MP", 5, 11)] | |
| for name, size, position in pattern_list: | |
| pattern_mat.append( | |
| (name, size, model_mat[position : position + size, position + 1 : position + 1 + size].copy()) | |
| ) | |
| dot = Digraph(comment="Model Flow Chart") | |
| node_idx = 0 | |
| for idx in range(model_size): | |
| for jdx in range(idx, model_size - 7): | |
| for name, size, pattern in pattern_mat: | |
| if (model_mat[idx : idx + size, jdx : jdx + size] == pattern).all(): | |
| layer_name[idx] = name | |
| model_mat[idx : idx + size, jdx : jdx + size] = False | |
| model_mat[idx, idx + size] = True | |
| dot.node(str(idx), f"{layer_name[idx]}") | |
| node_idx += 1 | |
| for jdx in range(idx, model_size): | |
| if model_mat[idx, jdx]: | |
| dot.edge(str(idx), str(jdx)) | |
| try: | |
| dot.render("Model-arch", format="png", cleanup=True) | |
| logger.info(":artist_palette: Drawing Model Architecture at Model-arch.png") | |
| except: | |
| logger.warning(":warning: Could not find graphviz backend, continue without drawing the model architecture") | |