Spaces:
Paused
Paused
| import folder_paths | |
| import impact.mmdet_nodes as mmdet_nodes | |
| from impact.utils import * | |
| from impact.core import SEG | |
| import impact.core as core | |
| import nodes | |
class NO_BBOX_MODEL:
    """Placeholder put in the BBOX_MODEL output slot when MMDetLoader loads a non-bbox model."""
class NO_SEGM_MODEL:
    """Placeholder put in the SEGM_MODEL output slot when MMDetLoader loads a bbox model."""
class MMDetLoader:
    """Legacy node: load an MMDetection model as either a bbox or a segm model.

    ``INPUT_TYPES`` offers every file from the ``mmdets_bbox`` folder prefixed
    with ``bbox/`` and every file from ``mmdets_segm`` prefixed with ``segm/``.
    The loaded model goes into the output slot matching its prefix; the other
    slot receives a ``NO_BBOX_MODEL`` / ``NO_SEGM_MODEL`` placeholder.
    """

    # Must be a classmethod: the node host calls ``MMDetLoader.INPUT_TYPES()``
    # on the class without an instance.
    @classmethod
    def INPUT_TYPES(s):
        bboxs = ["bbox/"+x for x in folder_paths.get_filename_list("mmdets_bbox")]
        segms = ["segm/"+x for x in folder_paths.get_filename_list("mmdets_segm")]
        return {"required": {"model_name": (bboxs + segms, )}}

    RETURN_TYPES = ("BBOX_MODEL", "SEGM_MODEL")
    FUNCTION = "load_mmdet"

    CATEGORY = "ImpactPack/Legacy"

    def load_mmdet(self, model_name):
        """Load ``model_name`` and return ``(bbox_model, segm_model)``.

        Args:
            model_name: a ``bbox/…`` or ``segm/…`` entry from ``INPUT_TYPES``;
                resolved through the ``mmdets`` folder.

        Returns:
            ``(model, NO_SEGM_MODEL())`` for bbox names, otherwise
            ``(NO_BBOX_MODEL(), model)``.
        """
        mmdet_path = folder_paths.get_full_path("mmdets", model_name)
        model = mmdet_nodes.load_mmdet(mmdet_path)

        if model_name.startswith("bbox"):
            return model, NO_SEGM_MODEL()
        else:
            return NO_BBOX_MODEL(), model
class BboxDetectorForEach:
    """Legacy node: run an MMDet bbox model and emit one SEG per detection.

    Each detection is cropped out of the input image with context controlled
    by ``crop_factor`` and packaged as a ``SEG``.
    """

    # Must be a classmethod: the node host calls it on the class itself.
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "bbox_model": ("BBOX_MODEL", ),
                    "image": ("IMAGE", ),
                    "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "dilation": ("INT", {"default": 10, "min": 0, "max": 255, "step": 1}),
                    "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
                    }
                }

    RETURN_TYPES = ("SEGS", )
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Legacy"

    @staticmethod
    def detect(bbox_model, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
        """Detect boxes in ``image`` and wrap each large-enough one in a ``SEG``.

        Args:
            bbox_model: model forwarded to ``mmdet_nodes.inference_bbox``.
            image: image tensor whose dims 1 and 2 are read as height and width
                (presumably ComfyUI's (B, H, W, C) layout -- confirm).
            threshold: confidence threshold forwarded to inference.
            dilation: when > 0, masks are dilated by this amount.
            crop_factor: forwarded to ``make_crop_region`` to size each crop.
            drop_size: detections whose width or height is <= this are skipped.
            detailer_hook: accepted for call compatibility; not used here.

        Returns:
            ``((h, w), items)`` where ``items`` is a list of ``SEG``.
        """
        mmdet_results = mmdet_nodes.inference_bbox(bbox_model, image, threshold)
        segmasks = core.create_segmasks(mmdet_results)

        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)

        items = []
        h = image.shape[1]
        w = image.shape[2]

        for x in segmasks:
            item_bbox = x[0]
            item_mask = x[1]

            # bbox layout is (x1, y1, x2, y2): width = [2]-[0], height = [3]-[1].
            x1, y1, x2, y2 = item_bbox

            if x2 - x1 > drop_size and y2 - y1 > drop_size:
                crop_region = make_crop_region(w, h, item_bbox, crop_factor)
                cropped_image = crop_image(image, crop_region)
                cropped_mask = crop_ndarray2(item_mask, crop_region)
                confidence = x[2]

                item = SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, None, None)
                items.append(item)

        shape = h, w
        return shape, items

    def doit(self, bbox_model, image, threshold, dilation, crop_factor):
        """Node entry point: return ``(segs,)`` using the default ``drop_size``."""
        return (BboxDetectorForEach.detect(bbox_model, image, threshold, dilation, crop_factor), )
class SegmDetectorCombined:
    """Legacy node: run an MMDet segm model and merge all detections into one MASK."""

    # Must be a classmethod: the node host calls it on the class itself.
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                        "segm_model": ("SEGM_MODEL", ),
                        "image": ("IMAGE", ),
                        "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                        "dilation": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
                      }
                }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Legacy"

    def doit(self, segm_model, image, threshold, dilation):
        """Detect, optionally dilate, and union every mask; return ``(mask,)``."""
        mmdet_results = mmdet_nodes.inference_segm(image, segm_model, threshold)
        segmasks = core.create_segmasks(mmdet_results)
        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)

        mask = combine_masks(segmasks)
        if mask is None:
            # NOTE(review): combine_masks is assumed to return None when there are
            # no detections; emit an all-zero (h, w) mask instead of a None MASK.
            mask = torch.zeros((image.shape[1], image.shape[2]), dtype=torch.float32)
        return (mask,)
class BboxDetectorCombined(SegmDetectorCombined):
    """Legacy node: run an MMDet bbox model and merge all detections into one MASK.

    Inherits ``RETURN_TYPES``, ``FUNCTION`` and ``CATEGORY`` from
    ``SegmDetectorCombined``; only the inputs and the inference call differ.
    """

    # Must be a classmethod: the node host calls it on the class itself.
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                        "bbox_model": ("BBOX_MODEL", ),
                        "image": ("IMAGE", ),
                        "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                        "dilation": ("INT", {"default": 4, "min": 0, "max": 255, "step": 1}),
                      }
                }

    def doit(self, bbox_model, image, threshold, dilation):
        """Detect boxes, optionally dilate, and union every mask; return ``(mask,)``."""
        mmdet_results = mmdet_nodes.inference_bbox(bbox_model, image, threshold)
        segmasks = core.create_segmasks(mmdet_results)
        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)

        mask = combine_masks(segmasks)
        if mask is None:
            # NOTE(review): combine_masks is assumed to return None when there are
            # no detections; emit an all-zero (h, w) mask instead of a None MASK.
            mask = torch.zeros((image.shape[1], image.shape[2]), dtype=torch.float32)
        return (mask,)
class SegmDetectorForEach:
    """Legacy node: run an MMDet segm model and emit one SEG per detection.

    Unlike ``BboxDetectorForEach`` no minimum-size filter is applied: every
    detection returned by the model becomes a SEG.
    """

    # Must be a classmethod: the node host calls it on the class itself.
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                        "segm_model": ("SEGM_MODEL", ),
                        "image": ("IMAGE", ),
                        "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                        "dilation": ("INT", {"default": 10, "min": 0, "max": 255, "step": 1}),
                        "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
                      }
                }

    RETURN_TYPES = ("SEGS", )
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Legacy"

    def doit(self, segm_model, image, threshold, dilation, crop_factor):
        """Return ``(((h, w), items),)`` where ``items`` is a list of ``SEG``.

        ``image`` dims 1 and 2 are read as height and width (presumably
        ComfyUI's (B, H, W, C) layout -- confirm).
        """
        mmdet_results = mmdet_nodes.inference_segm(image, segm_model, threshold)
        segmasks = core.create_segmasks(mmdet_results)

        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)

        items = []
        h = image.shape[1]
        w = image.shape[2]

        for x in segmasks:
            item_bbox = x[0]
            item_mask = x[1]

            crop_region = make_crop_region(w, h, item_bbox, crop_factor)
            cropped_image = crop_image(image, crop_region)
            cropped_mask = crop_ndarray2(item_mask, crop_region)
            confidence = x[2]

            item = SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, None, None)
            items.append(item)

        shape = h, w
        return ((shape, items), )
class SegsMaskCombine:
    """Legacy node: merge every SEG of a SEGS bundle into one image-sized MASK."""

    # Must be a classmethod: the node host calls it on the class itself.
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                        "segs": ("SEGS", ),
                        "image": ("IMAGE", ),
                     }
                }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Legacy"

    @staticmethod
    def combine(segs, image):
        """Paste each SEG's cropped mask back into an image-sized canvas.

        Args:
            segs: ``(shape, items)`` pair; only ``items`` (``segs[1]``) is read.
                Each item must expose ``cropped_mask`` (array, values presumably
                in [0, 1]) and ``crop_region`` laid out as ``(x1, y1, x2, y2)``.
            image: tensor whose dims 1 and 2 give the canvas height and width.

        Returns:
            float32 torch tensor of shape ``(h, w)`` with values in [0, 1].
        """
        h = image.shape[1]
        w = image.shape[2]

        mask = np.zeros((h, w), dtype=np.uint8)

        for seg in segs[1]:
            x1, y1, x2, y2 = seg.crop_region
            cropped_mask = (seg.cropped_mask * 255).astype(np.uint8)
            region = mask[y1:y2, x1:x2]
            # Element-wise maximum is a proper union for soft masks; a bitwise OR
            # of quantised values corrupts them (e.g. 63 | 95 == 127, not 95).
            np.maximum(region, cropped_mask, out=region)

        return torch.from_numpy(mask.astype(np.float32) / 255.0)

    def doit(self, segs, image):
        """Node entry point: return ``(mask,)``."""
        return (SegsMaskCombine.combine(segs, image), )
class MaskPainter(nodes.PreviewImage):
    """Legacy node: preview ``images`` and return a mask loaded from a painted image.

    On a new input the images are saved as a preview and the first saved file
    is re-loaded with ``nodes.LoadImage`` to obtain its mask. When the UI sends
    back an ``image`` dict whose ``image_hash`` equals ``id(images)``, the
    painted file that dict references is loaded instead.
    """

    # Must be a classmethod: the node host calls it on the class itself.
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"images": ("IMAGE",), },
                "hidden": {
                    "prompt": "PROMPT",
                    "extra_pnginfo": "EXTRA_PNGINFO",
                },
                # Exactly one "optional" key: in a dict literal a repeated key keeps
                # only its last value and silently discards the earlier one.
                # ``mask_image`` is not exposed; save_painted_images ignores it.
                "optional": {"image": (["#placeholder"], )},
                }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "save_painted_images"

    CATEGORY = "ImpactPack/Legacy"

    def save_painted_images(self, images, filename_prefix="impact-mask",
                            prompt=None, extra_pnginfo=None, mask_image=None, image=None):
        """Return a UI/result dict whose ``result`` is ``(mask,)``.

        Args:
            images: IMAGE input to preview.
            filename_prefix: prefix forwarded to ``save_images``.
            prompt, extra_pnginfo: hidden metadata forwarded to ``save_images``.
            mask_image: accepted for compatibility; not used.
            image: ``None`` / ``"#placeholder"`` for a fresh input, otherwise a
                dict sent back by the UI (keys read below: ``image_hash``,
                ``filename``, ``subfolder``, ``type`` and ``forward_*``; an
                optional ``'0'`` key wraps that dict).
        """
        if image is None or image == "#placeholder" or image['image_hash'] != id(images):
            # new input image: save a preview and load the mask from it
            res = self.save_images(images, filename_prefix, prompt, extra_pnginfo)

            item = res['ui']['images'][0]

            filepath = item['filename']
            # A filename ending in ']' is treated as already annotated; otherwise
            # build "subfolder/name [type]" like the painted-mask branch below.
            if not filepath.endswith(']'):
                subfolder = item.get('subfolder', "")
                if subfolder:
                    filepath = f"{subfolder}/{filepath}"
                filepath = f"{filepath} [{item['type']}]"

            _, mask = nodes.LoadImage().load_image(filepath)

            res['ui']['aux'] = [id(images), res['ui']['images']]
            res['result'] = (mask, )

            return res

        # new mask painted in the UI
        if '0' in image:  # fallback
            image = image['0']

        forward = {'filename': image['forward_filename'],
                   'subfolder': image['forward_subfolder'],
                   'type': image['forward_type'], }

        res = {'ui': {'images': [forward]}}

        imgpath = ""
        if 'subfolder' in image and image['subfolder'] != "":
            imgpath = image['subfolder'] + "/"

        imgpath += f"{image['filename']}"

        if 'type' in image and image['type'] != "":
            imgpath += f" [{image['type']}]"

        res['ui']['aux'] = [id(images), [forward]]
        _, mask = nodes.LoadImage().load_image(imgpath)
        res['result'] = (mask, )

        return res