Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
| from joblib import Parallel, delayed | |
| from pytlsd import lsd | |
| from ..base_model import BaseModel | |
class LSD(BaseModel):
    """Line segment detector built on ``pytlsd.lsd``.

    For each image, detected segments shorter than ``min_length`` are
    discarded. The remaining ones are ranked by ``score * sqrt(length)`` and
    optionally truncated to ``max_num_lines``. When ``force_num_lines`` is
    set, every image is zero-padded to exactly ``max_num_lines`` segments so
    the whole batch can be stacked into tensors.
    """

    default_conf = {
        "min_length": 15,  # segments with endpoint distance below this are dropped
        "max_num_lines": None,  # keep at most this many lines per image (None = all)
        "force_num_lines": False,  # pad every image to exactly max_num_lines
        "n_jobs": 4,  # joblib workers, used only when the batch has > 1 image
    }
    required_data_keys = ["image"]

    def _init(self, conf):
        """Validate the configuration.

        Raises:
            AssertionError: if ``force_num_lines`` is set without
                ``max_num_lines``, since padding needs a target count.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if self.conf.force_num_lines and self.conf.max_num_lines is None:
            raise AssertionError("Missing max_num_lines parameter")

    def detect_lines(self, img):
        """Detect, filter and rank the line segments of a single image.

        Args:
            img: single-channel image passed directly to ``pytlsd.lsd``
                (``_forward`` supplies a 2D ``uint8`` array).

        Returns:
            A tuple ``(segs, scores, valid_mask)``:

            - ``segs``: ``(N, 2, 2)`` array, one pair of 2D endpoints per
              segment (columns 0-1 and 2-3 of the LSD output).
            - ``scores``: ``(N,)`` ranking scores; real detections are in
              decreasing order, followed by zero padding if any.
            - ``valid_mask``: ``(N,)`` boolean array, ``False`` only for
              padded entries.

            ``N == max_num_lines`` when ``force_num_lines`` is set.
        """
        segs = lsd(img)

        # Drop segments whose endpoint-to-endpoint length is below min_length.
        lengths = np.linalg.norm(segs[:, 2:4] - segs[:, 0:2], axis=1)
        to_keep = lengths >= self.conf.min_length
        segs, lengths = segs[to_keep], lengths[to_keep]

        # Rank by the last LSD output column weighted by sqrt(length).
        # NOTE(review): the last column is assumed to be a per-segment
        # confidence produced by pytlsd -- confirm against its documentation.
        scores = segs[:, -1] * np.sqrt(lengths)
        segs = segs[:, :4].reshape(-1, 2, 2)
        # Stable sort so tied scores are ordered (and truncated) deterministically.
        order = np.argsort(-scores, kind="stable")
        if self.conf.max_num_lines is not None:
            order = order[: self.conf.max_num_lines]
        segs, scores = segs[order], scores[order]

        n = len(segs)
        valid_mask = np.ones(n, dtype=bool)
        if self.conf.force_num_lines:
            # pad >= 0: max_num_lines is non-None here (checked in _init) and
            # the ranking above already truncated to at most that many lines.
            pad = self.conf.max_num_lines - n
            segs = np.concatenate(
                [segs, np.zeros((pad, 2, 2), dtype=np.float32)], axis=0
            )
            scores = np.concatenate([scores, np.zeros(pad, dtype=np.float32)], axis=0)
            valid_mask = np.concatenate([valid_mask, np.zeros(pad, dtype=bool)], axis=0)
        return segs, scores, valid_mask

    def _forward(self, data):
        """Run LSD on every image of the batch.

        Args:
            data: dict whose ``"image"`` entry is a ``(B, C, H, W)`` tensor.
                ``C == 3`` is converted to grayscale; otherwise ``C`` is
                expected to be 1. Values are presumably in ``[0, 1]``: they
                are scaled by 255 and clipped before the ``uint8`` cast.

        Returns:
            Dict with ``"lines"``, ``"line_scores"`` and ``"valid_lines"``.
            For a single image, or when ``force_num_lines`` is set, these are
            tensors on the input device with shapes ``(B, N, 2, 2)``,
            ``(B, N)`` and ``(B, N)``. Otherwise they are tuples holding one
            numpy array per image, because ``N`` differs between images.
        """
        image = data["image"]
        if image.shape[1] == 3:
            # RGB -> grayscale with the ITU-R BT.601 luma weights.
            scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            image = (image * scale).sum(1, keepdim=True)
        device = image.device
        b_size = len(image)
        # detach() lets tensors that require grad be converted. Clipping
        # saturates out-of-range values instead of letting the uint8 cast wrap.
        gray = image.squeeze(1).detach().cpu().numpy() * 255
        image = np.clip(gray, 0, 255).astype(np.uint8)

        # LSD detection, in parallel across images when there is more than one.
        if b_size == 1:
            results = [self.detect_lines(image[0])]
        else:
            results = Parallel(n_jobs=self.conf.n_jobs)(
                delayed(self.detect_lines)(img) for img in image
            )
        lines, line_scores, valid_lines = zip(*results)

        # Batch if possible: per-image outputs have equal lengths only for a
        # single image or when every image was padded to max_num_lines.
        if b_size == 1 or self.conf.force_num_lines:
            # Stack in numpy first. Building a tensor from a list of ndarrays
            # is very slow in torch and emits a UserWarning.
            lines = torch.as_tensor(np.stack(lines), dtype=torch.float, device=device)
            line_scores = torch.as_tensor(
                np.stack(line_scores), dtype=torch.float, device=device
            )
            valid_lines = torch.as_tensor(
                np.stack(valid_lines), dtype=torch.bool, device=device
            )

        return {"lines": lines, "line_scores": line_scores, "valid_lines": valid_lines}

    def loss(self, pred, data):
        """Not supported: this detector has no trainable loss."""
        raise NotImplementedError