# Adopted from https://github.com/magic-research/Sa2VA/blob/main/projects/llava_sam2/models/extension/sam2_base.py.
# Below is the original copyright:
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F

from third_parts.sam2.modeling.sam2_base import SAM2Base as _SAM2Base
from third_parts.sam2.modeling.sam2_base import NO_OBJ_SCORE


class SAM2Base(_SAM2Base):
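    """SAM2Base extended with an optional language prompt.

    On top of the upstream `track_step` / `_forward_sam_heads` logic, this
    subclass accepts `language_embd` (e.g. embeddings produced by an LLM) and
    feeds them to the SAM mask decoder as additional sparse prompt tokens.
    """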
    def track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
        # Whether to run the memory encoder on the predicted masks. Sometimes we might want
        # to skip the memory encoder with `run_mem_encoder=False`. For example,
        # in demo we might call `track_step` multiple times for each user click,
        # and only encode the memory when the user finalizes their clicks. And in ablation
        # settings like SAM training on static images, we don't need the memory encoder.
        run_mem_encoder=True,
        # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
        prev_sam_mask_logits=None,
        ## Extension: LLM prompt
        language_embd=None,
    ):
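        """Run a single tracking step on one frame.

        The current frame features are conditioned on the memory bank, the SAM
        heads are run with any point/mask/language prompts, and (optionally)
        the predicted mask is encoded into a new memory feature for future frames.
        """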
        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
        # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
        if len(current_vision_feats) > 1:
            high_res_features = [
                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
            ]
        else:
            high_res_features = None
        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
            # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
            sam_outputs = self._use_mask_as_output(
                pix_feat, high_res_features, mask_inputs
            )
        else:
            # fuse the visual features with previous memory features in the memory bank
            pix_feat_with_mem = self._prepare_memory_conditioned_features(
                frame_idx=frame_idx,
                is_init_cond_frame=is_init_cond_frame,
                current_vision_feats=current_vision_feats[-1:],
                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
                feat_sizes=feat_sizes[-1:],
                output_dict=output_dict,
                num_frames=num_frames,
                track_in_reverse=track_in_reverse,
            )
            # apply SAM-style segmentation head
            # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
            # e.g. in demo where such logits come from earlier interaction instead of correction sampling
            # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
            if prev_sam_mask_logits is not None:
                assert point_inputs is not None and mask_inputs is None
                mask_inputs = prev_sam_mask_logits
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            sam_outputs = self._forward_sam_heads(
                backbone_features=pix_feat_with_mem,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
                # Inject the language embedding (if provided) as an extra prompt
                language_embd=language_embd,
            )
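        # Keep only the selected low-/high-res masks and the object pointer from
        # the SAM head outputs; the other entries are not stored in `current_out`.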
        (
            _,
            _,
            _,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            _,
        ) = sam_outputs

        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr

        # Finally run the memory encoder on the predicted mask to encode
        # it into a new memory feature (that can be used in future frames)
        if run_mem_encoder and self.num_maskmem > 0:
            high_res_masks_for_mem_enc = high_res_masks
            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
                current_vision_feats=current_vision_feats,
                feat_sizes=feat_sizes,
                pred_masks_high_res=high_res_masks_for_mem_enc,
                is_mask_from_pts=(point_inputs is not None),
            )
            current_out["maskmem_features"] = maskmem_features
            current_out["maskmem_pos_enc"] = maskmem_pos_enc
        else:
            current_out["maskmem_features"] = None
            current_out["maskmem_pos_enc"] = None

        return current_out
    def _forward_sam_heads(
        self,
        backbone_features,
        point_inputs=None,
        mask_inputs=None,
        high_res_features=None,
        multimask_output=False,
        ## Extension: LLM prompt
        language_embd=None,
    ):
| """ | |
| Forward SAM prompt encoders and mask heads. | |
| Inputs: | |
| - backbone_features: image features of [B, C, H, W] shape | |
| - point_inputs: a dictionary with "point_coords" and "point_labels", where | |
| 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the | |
| absolute pixel-unit coordinate in (x, y) format of the P input points | |
| 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means | |
| positive clicks, 0 means negative clicks, and -1 means padding | |
| - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the | |
| same spatial size as the image. | |
| - high_res_features: either 1) None or 2) or a list of length 2 containing | |
| two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, | |
| which will be used as high-resolution feature maps for SAM decoder. | |
| - multimask_output: if it's True, we output 3 candidate masks and their 3 | |
| corresponding IoU estimates, and if it's False, we output only 1 mask and | |
| its corresponding IoU estimate. | |
| Outputs: | |
| - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if | |
| `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM | |
| output mask logits (before sigmoid) for the low-resolution masks, with 4x | |
| the resolution (1/4 stride) of the input backbone_features. | |
| - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 | |
| if `multimask_output=True` and M = 1 if `multimask_output=False`), | |
| upsampled from the low-resolution masks, with shape size as the image | |
| (stride is 1 pixel). | |
| - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1 | |
| if `multimask_output=False`), the estimated IoU of each output mask. | |
| - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. | |
| If `multimask_output=True`, it's the mask with the highest IoU estimate. | |
| If `multimask_output=False`, it's the same as `low_res_multimasks`. | |
| - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. | |
| If `multimask_output=True`, it's the mask with the highest IoU estimate. | |
| If `multimask_output=False`, it's the same as `high_res_multimasks`. | |
| - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted | |
| based on the output token from the SAM mask decoder. | |
| """ | |
        B = backbone_features.size(0)
        device = backbone_features.device
        assert backbone_features.size(1) == self.sam_prompt_embed_dim
        assert backbone_features.size(2) == self.sam_image_embedding_size
        assert backbone_features.size(3) == self.sam_image_embedding_size

        # a) Handle point prompts
        if point_inputs is not None:
            sam_point_coords = point_inputs["point_coords"]
            sam_point_labels = point_inputs["point_labels"]
            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
        else:
            # If no points are provided, pad with an empty point (with label -1)
            sam_point_coords = torch.zeros(B, 1, 2, device=device)
            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
        # b) Handle mask prompts
        if mask_inputs is not None:
            # If mask_inputs is provided, downsize it into low-res mask input if needed
            # and feed it as a dense mask prompt into the SAM mask encoder
            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                sam_mask_prompt = F.interpolate(
                    mask_inputs.float(),
                    size=self.sam_prompt_encoder.mask_input_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,  # use antialias for downsampling
                )
            else:
                sam_mask_prompt = mask_inputs
        else:
            # Otherwise, simply feed None (and SAM's prompt encoder will add
            # a learned `no_mask_embed` to indicate no mask input in this case).
            sam_mask_prompt = None

        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
            points=(sam_point_coords, sam_point_labels),
            boxes=None,
            masks=sam_mask_prompt,
        )
        ## Extension: LLM prompt
        if language_embd is not None:
            # language_embd: [B, N, C], matching the batch and channel dims of sparse_embeddings
            # print('sparse_embeddings ', sparse_embeddings.shape, 'language_embd ', language_embd.shape)
            assert sparse_embeddings.size(0) == language_embd.size(0)
            assert sparse_embeddings.size(2) == language_embd.size(2)
            sparse_embeddings = torch.cat([sparse_embeddings, language_embd], dim=1)
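        # The language tokens (if any) are now part of the sparse prompt embeddings,
        # so the SAM mask decoder attends to them just like point prompt tokens.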
        (
            low_res_multimasks,
            ious,
            sam_output_tokens,
            object_score_logits,
        ) = self.sam_mask_decoder(
            image_embeddings=backbone_features,
            image_pe=self.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=False,  # the image is already batched
            high_res_features=high_res_features,
        )
        if self.pred_obj_scores:
            is_obj_appearing = object_score_logits > 0

            # Mask used for spatial memories is always a *hard* choice between obj and no obj,
            # consistent with the actual mask prediction
            # print('Do torch.where !!!')
            # low_res_multimasks = torch.where(
            #     is_obj_appearing[:, None, None],
            #     low_res_multimasks,
            #     NO_OBJ_SCORE,
            # )

        # convert masks from possibly bfloat16 (or float16) to float32
        # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
        low_res_multimasks = low_res_multimasks.float()
        high_res_multimasks = F.interpolate(
            low_res_multimasks,
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )
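        # `high_res_multimasks` is now at the model's input image resolution (self.image_size).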
        sam_output_token = sam_output_tokens[:, 0]
        if multimask_output:
            # take the best mask prediction (with the highest IoU estimation)
            best_iou_inds = torch.argmax(ious, dim=-1)
            batch_inds = torch.arange(B, device=device)
            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            if sam_output_tokens.size(1) > 1:
                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
        else:
            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks

        # Extract object pointer from the SAM output token (with occlusion handling)
        obj_ptr = self.obj_ptr_proj(sam_output_token)
        if self.pred_obj_scores:
            # Allow *soft* no obj ptr, unlike for masks
            if self.soft_no_obj_ptr:
                # Only hard possible with gt
                assert not self.teacher_force_obj_scores_for_mem
                lambda_is_obj_appearing = object_score_logits.sigmoid()
            else:
                lambda_is_obj_appearing = is_obj_appearing.float()

            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

        return (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )