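# Gradio demo app for RedRocket's JTP-3 "Hydra" multi-label image tagger.
# Workflow: patchify an image, run the ViT classifier for per-tag scores, and
# optionally render a gradient-based CAM overlay for a chosen tag.
# (Orientation comment added by the editor; `model` and `image` below are
# local modules shipped alongside this Space.)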
from io import BytesIO
from threading import Lock

import numpy as np
import torch
from torch import Tensor
from torch.nn import Parameter
import spaces
from huggingface_hub import hf_hub_download
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import requests

from model import load_model, process_image, patchify_image
from image import unpatchify
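# Global setup: the checkpoint is downloaded from the Hub and loaded once at
# startup; `model_lock` serializes model access because Gradio may service
# requests from multiple threads. (`spaces` looks unused, but importing it is
# presumably how this Space opts into Hugging Face ZeroGPU handling; that
# reading is an assumption on our part.)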
device = "cuda" if torch.cuda.is_available() else "cpu"

PATCH_SIZE = 16
MAX_SEQ_LEN = 1024

model_lock = Lock()
model, tag_list = load_model(
    hf_hub_download(repo_id="RedRocket/JTP-3", filename="models/jtp-3-hydra.safetensors"),
    device=device
)
model.requires_grad_(False)

tags = {
    tag.replace("_", " ").replace("vulva", "pussy"): idx
    for idx, tag in enumerate(tag_list)
}
tag_list = list(tags.keys())

FONT = ImageFont.load_default(24)
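# Inference: split the image into 16x16 patches, normalize pixel values to
# [-1, 1], run the ViT (capturing the last `cam_depth` intermediate layers for
# later CAM use), and keep the top 250 tag scores.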
def run_classifier(image: Image.Image, cam_depth: int):
    patches, patch_coords, patch_valid = patchify_image(image, PATCH_SIZE, MAX_SEQ_LEN)
    patches = patches.unsqueeze(0).to(device=device, non_blocking=True)
    patch_coords = patch_coords.unsqueeze(0).to(device=device, non_blocking=True)
    patch_valid = patch_valid.unsqueeze(0).to(device=device, non_blocking=True)
    patches = patches.to(dtype=torch.bfloat16).div_(127.5).sub_(1.0)
    patch_coords = patch_coords.to(dtype=torch.int32)
    with model_lock:
        features = model.forward_intermediates(
            patches,
            patch_coord=patch_coords,
            patch_valid=patch_valid,
            indices=cam_depth,
            output_dict=True,
            output_fmt='NLC'
        )
        logits = model.forward_head(features["image_features"], patch_valid=patch_valid)
    del features["image_features"]
    features["patch_coords"] = patch_coords
    features["patch_valid"] = patch_valid
    del patches, patch_coords, patch_valid
    probits = logits[0].float().sigmoid_().mul_(2.0).sub_(1.0)  # scale from [0, 1] to [-1, 1]
    values, indices = probits.cpu().topk(250)
    predictions = {
        tag_list[idx.item()]: val.item()
        for idx, val in sorted(
            zip(indices, values),
            key=lambda item: item[1].item(),
            reverse=True
        )
    }
    return features, predictions
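# Gradient-based CAM for a single tag: the attention-pool query and output
# projection are temporarily sliced down to the selected tag so the head
# yields one scalar logit, whose gradient w.r.t. each cached intermediate is
# reduced to a per-patch relevance score. grad * sign(activation) gives a
# Grad-CAM-style attribution; summing over the batch and channel dims leaves
# one signed value per patch.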
def run_cam(
    display_image: Image.Image,
    image: Image.Image, features: dict[str, Tensor],
    tag_idx: int, cam_depth: int
):
    intermediates = features["image_intermediates"]
    if len(intermediates) < cam_depth:
        features, _ = run_classifier(image, cam_depth)
        intermediates = features["image_intermediates"]
    elif len(intermediates) > cam_depth:
        intermediates = intermediates[-cam_depth:]
    patch_coords = features["patch_coords"]
    patch_valid = features["patch_valid"]
    with model_lock:
        saved_q = model.attn_pool.q
        saved_p = model.attn_pool.out_proj.weight
        try:
            model.attn_pool.q = Parameter(saved_q[:, [tag_idx], :], requires_grad=False)
            model.attn_pool.out_proj.weight = Parameter(saved_p[[tag_idx], :, :], requires_grad=False)
            with torch.enable_grad():
                for intermediate in intermediates:
                    intermediate.requires_grad_(True).retain_grad()
                    model.forward_head(intermediate, patch_valid=patch_valid)[0, 0].backward()
        finally:
            model.attn_pool.q = saved_q
            model.attn_pool.out_proj.weight = saved_p
    cam_1d: Tensor | None = None
    for intermediate in intermediates:
        patch_grad = (intermediate.grad.float() * intermediate.sign()).sum(dim=(0, 2))
        intermediate.grad = None
        if cam_1d is None:
            cam_1d = patch_grad
        else:
            cam_1d.add_(patch_grad)
    assert cam_1d is not None
    cam_2d = unpatchify(cam_1d, patch_coords, patch_valid).cpu().numpy()
    return cam_composite(display_image, cam_2d), features
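# Rendering: negative CAM values go into the red channel, positive into green,
# with alpha proportional to magnitude (at most 50% opacity). The peak |CAM|
# value is stamped in the corner so overlays for different tags and depths
# remain comparable despite per-image normalization.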
def cam_composite(image: Image.Image, cam: np.ndarray):
    """
    Overlays the CAM on the image and returns a PIL image.

    Args:
        image: PIL Image (RGB)
        cam: 2D numpy array (activation map)

    Returns:
        PIL.Image.Image with overlay
    """
    cam_abs = np.abs(cam)
    cam_scale = cam_abs.max()
    cam_rgba = np.dstack((
        (cam < 0).astype(np.float32),
        (cam > 0).astype(np.float32),
        np.zeros_like(cam, dtype=np.float32),
        cam_abs * (0.5 / cam_scale),
    ))  # Shape: (H, W, 4)
    cam_pil = Image.fromarray((cam_rgba * 255).astype(np.uint8))
    cam_pil = cam_pil.resize(image.size, resample=Image.Resampling.NEAREST)
    image = Image.blend(
        image.convert('RGBA'),
        image.convert('L').convert('RGBA'),
        0.33
    )
    image = Image.alpha_composite(image, cam_pil)
    draw = ImageDraw.Draw(image)
    draw.text(
        (image.width - 7, image.height - 7),
        f"{cam_scale.item():.4g}",
        anchor="rd", font=FONT, fill=(32, 32, 255, 255)
    )
    return image
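# UI helpers: threshold filtering over the cached prediction dict, and a
# display-size downscale so the browser never handles full-resolution uploads.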
def filter_tags(predictions: dict[str, float], threshold: float):
    predictions = {
        key: value
        for key, value in predictions.items()
        if value >= threshold
    }
    tag_str = ", ".join(predictions.keys())
    return tag_str, predictions
def resize_image(image: Image.Image) -> Image.Image:
    longest_side = max(image.height, image.width)
    if longest_side < 1080:
        return image
    scale = 1080 / longest_side
    return image.resize(
        (
            int(round(image.width * scale)),
            int(round(image.height * scale)),
        ),
        resample=Image.Resampling.LANCZOS,
        reducing_gap=3.0
    )
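# Upload handling: the display copy and the model-ready copy are derived
# separately, and the original PIL image is closed once neither derived copy
# aliases it. The first four return values reset the tag and CAM widgets.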
def image_upload(image: Image.Image):
    display_image = resize_image(image)
    processed_image = process_image(image, PATCH_SIZE, MAX_SEQ_LEN)
    if display_image is not image and processed_image is not image:
        image.close()
    return (
        "", {}, "None", "",
        gr.skip() if display_image is image else display_image, display_image,
        processed_image,
    )
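# URL handling mirrors image_upload but fetches the file first; the 10-second
# timeout and raise_for_status() make a bad URL surface as a Gradio error
# rather than a hang.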
def url_submit(url: str):
    resp = requests.get(url, timeout=10)
    resp.raise_for_status()
    image = Image.open(BytesIO(resp.content))
    display_image = resize_image(image)
    processed_image = process_image(image, PATCH_SIZE, MAX_SEQ_LEN)
    if display_image is not image and processed_image is not image:
        image.close()
    return (
        "", {}, "None",
        display_image, display_image,
        processed_image,
    )
def image_changed(image: Image.Image, threshold: float, cam_depth: int):
    features, predictions = run_classifier(image, cam_depth)
    return *filter_tags(predictions, threshold), features, predictions

def image_clear():
    return (
        "", {}, "None", "",
        None, None,
        None, None, {},
    )
def cam_changed(
    display_image: Image.Image,
    image: Image.Image, features: dict[str, Tensor],
    tag: str, cam_depth: int
):
    if tag == "None":
        return display_image, features
    return run_cam(display_image, image, features, tags[tag], cam_depth)

def tag_box_select(evt: gr.SelectData):
    return evt.value
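# CSS tweaks: hide Gradio's API footer, define an inferno-style gradient for
# any slider tagged with .inferno-slider, and letterbox the image preview
# into a square container.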
custom_css = """
.output-class { display: none; }
.inferno-slider input[type=range] {
    background: linear-gradient(to right,
        #000004, #1b0c41, #4a0c6b, #781c6d,
        #a52c60, #cf4446, #ed6925, #fb9b06,
        #f7d13d, #fcffa4
    ) !important;
    background-size: 100% 100% !important;
}
#image_container-image {
    width: 100%;
    aspect-ratio: 1 / 1;
    max-height: 100%;
}
#image_container img {
    object-fit: contain !important;
}
.show-api, .show-api-divider {
    display: none !important;
}
"""
with gr.Blocks(
    title="RedRocket JTP-3 Hydra Demo",
    css=custom_css,
    analytics_enabled=False,
) as demo:
    display_image_state = gr.State()
    image_state = gr.State()
    features_state = gr.State()
    predictions_state = gr.State(value={})

    gr.HTML(
        "<h1 style='display:flex; flex-flow: row nowrap; align-items: center;'>"
        "<a href='https://huggingface.co/RedRocket' target='_blank'>"
        "<img src='https://huggingface.co/spaces/RedRocket/README/resolve/main/RedRocket.png' style='width: 2em; margin-right: 0.5em;'>"
        "</a>"
        "<span><a href='https://huggingface.co/RedRocket' target='_blank'>RedRocket</a> – JTP-3 Hydra Demo</span>"
        "<span style='font-weight: normal;'> • <a href='https://huggingface.co/RedRocket/JTP-3' target='_blank'>Download</a></span>"
        "</h1>"
    )
    with gr.Row():
        with gr.Column():
            with gr.Column():
                image = gr.Image(
                    sources=['upload', 'clipboard'], type='pil',
                    show_label=False,
                    show_download_button=False,
                    show_share_button=False,
                    elem_id="image_container"
                )
                url = gr.Textbox(
                    label="Upload Image via URL:",
                    placeholder="https://example.com/image.jpg",
                    max_lines=1,
                    submit_btn="⮝",
                )
            with gr.Column():
                cam_tag = gr.Dropdown(
                    value="None", choices=["None"] + tag_list,
                    label="CAM Attention Overlay (You can also click a tag on the right.)", show_label=True
                )
                cam_depth = gr.Slider(
                    minimum=1, maximum=27, step=1, value=1,
                    label="CAM Depth (1 = fastest, more precise; 27 = slowest, more general)"
                )
        with gr.Column():
            threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.30, label="Tag Threshold")
            tag_string = gr.Textbox(lines=3, label="Tags", show_label=True, show_copy_button=True)
            tag_box = gr.Label(num_top_classes=250, show_label=False, show_heading=False)
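    # Event wiring: uploads and URL submissions first reset the widgets and
    # stash the processed image, then a chained .then() runs the classifier.
    # Splitting the two steps lets the preview update before inference starts.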
    image.upload(
        fn=image_upload,
        inputs=[image],
        outputs=[
            tag_string, tag_box, cam_tag, url,
            image, display_image_state,
            image_state,
        ],
        show_progress='minimal',
        show_progress_on=[image]
    ).then(
        fn=image_changed,
        inputs=[image_state, threshold_slider, cam_depth],
        outputs=[
            tag_string, tag_box,
            features_state, predictions_state,
        ],
        show_progress='minimal',
        show_progress_on=[tag_box]
    )

    url.submit(
        fn=url_submit,
        inputs=[url],
        outputs=[
            tag_string, tag_box, cam_tag,
            image, display_image_state,
            image_state,
        ],
        show_progress='minimal',
        show_progress_on=[url]
    ).then(
        fn=image_changed,
        inputs=[image_state, threshold_slider, cam_depth],
        outputs=[
            tag_string, tag_box,
            features_state, predictions_state,
        ],
        show_progress='minimal',
        show_progress_on=[tag_box]
    )
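    # The remaining handlers are lightweight: clearing resets all state,
    # threshold changes only re-filter cached predictions (no model call),
    # and the CAM controls re-render the overlay from cached intermediates.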
    image.clear(
        fn=image_clear,
        inputs=[],
        outputs=[
            tag_string, tag_box, cam_tag, url,
            image, display_image_state,
            image_state, features_state, predictions_state,
        ],
        show_progress='hidden'
    )

    threshold_slider.input(
        fn=filter_tags,
        inputs=[predictions_state, threshold_slider],
        outputs=[tag_string, tag_box],
        trigger_mode='always_last',
        show_progress='hidden'
    )

    cam_tag.input(
        fn=cam_changed,
        inputs=[
            display_image_state,
            image_state, features_state,
            cam_tag, cam_depth,
        ],
        outputs=[image, features_state],
        trigger_mode='always_last',
        show_progress='minimal',
        show_progress_on=[cam_tag]
    )

    cam_depth.input(
        fn=cam_changed,
        inputs=[
            display_image_state,
            image_state, features_state,
            cam_tag, cam_depth,
        ],
        outputs=[image, features_state],
        trigger_mode='always_last',
        show_progress='minimal',
        show_progress_on=[cam_depth]
    )
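    # Clicking an entry in the tag label widget selects it in the CAM dropdown
    # and then re-renders the overlay for that tag.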
    tag_box.select(
        fn=tag_box_select,
        inputs=[],
        outputs=[cam_tag],
        trigger_mode='always_last',
        show_progress='hidden',
    ).then(
        fn=cam_changed,
        inputs=[
            display_image_state,
            image_state, features_state,
            cam_tag, cam_depth,
        ],
        outputs=[image, features_state],
        show_progress='minimal',
        show_progress_on=[cam_tag]
    )

if __name__ == "__main__":
    demo.launch()