Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # Copyright (c) Meta Platforms, Inc. All Rights Reserved | |
| # Modified by Feng Liang from | |
| # https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/text_prompt.py | |
| # https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/utils.py | |
| from typing import List | |
| import clip | |
| import torch | |
| from torch import nn | |
# Prompt templates used by ImageNetPromptExtractor (80 entries).
#
# Each entry is a ``str.format`` pattern whose single ``{}`` placeholder is
# filled with a noun in PredefinedPromptExtractor.forward before the text is
# tokenized and encoded.
#
# NOTE: these strings are runtime data that is passed to the text encoder, so
# they are kept verbatim -- including wording such as "a origami {}." and
# "a embroidered {}." -- editing any of them changes the encoded text.
IMAGENET_PROMPT = [
    "a bad photo of a {}.",
    "a photo of many {}.",
    "a sculpture of a {}.",
    "a photo of the hard to see {}.",
    "a low resolution photo of the {}.",
    "a rendering of a {}.",
    "graffiti of a {}.",
    "a bad photo of the {}.",
    "a cropped photo of the {}.",
    "a tattoo of a {}.",
    "the embroidered {}.",
    "a photo of a hard to see {}.",
    "a bright photo of a {}.",
    "a photo of a clean {}.",
    "a photo of a dirty {}.",
    "a dark photo of the {}.",
    "a drawing of a {}.",
    "a photo of my {}.",
    "the plastic {}.",
    "a photo of the cool {}.",
    "a close-up photo of a {}.",
    "a black and white photo of the {}.",
    "a painting of the {}.",
    "a painting of a {}.",
    "a pixelated photo of the {}.",
    "a sculpture of the {}.",
    "a bright photo of the {}.",
    "a cropped photo of a {}.",
    "a plastic {}.",
    "a photo of the dirty {}.",
    "a jpeg corrupted photo of a {}.",
    "a blurry photo of the {}.",
    "a photo of the {}.",
    "a good photo of the {}.",
    "a rendering of the {}.",
    "a {} in a video game.",
    "a photo of one {}.",
    "a doodle of a {}.",
    "a close-up photo of the {}.",
    "a photo of a {}.",
    "the origami {}.",
    "the {} in a video game.",
    "a sketch of a {}.",
    "a doodle of the {}.",
    "a origami {}.",
    "a low resolution photo of a {}.",
    "the toy {}.",
    "a rendition of the {}.",
    "a photo of the clean {}.",
    "a photo of a large {}.",
    "a rendition of a {}.",
    "a photo of a nice {}.",
    "a photo of a weird {}.",
    "a blurry photo of a {}.",
    "a cartoon {}.",
    "art of a {}.",
    "a sketch of the {}.",
    "a embroidered {}.",
    "a pixelated photo of a {}.",
    "itap of the {}.",
    "a jpeg corrupted photo of the {}.",
    "a good photo of a {}.",
    "a plushie {}.",
    "a photo of the nice {}.",
    "a photo of the small {}.",
    "a photo of the weird {}.",
    "the cartoon {}.",
    "art of the {}.",
    "a drawing of the {}.",
    "a photo of the large {}.",
    "a black and white photo of a {}.",
    "the plushie {}.",
    "a dark photo of a {}.",
    "itap of a {}.",
    "graffiti of the {}.",
    "a toy {}.",
    "itap of my {}.",
    "a photo of a cool {}.",
    "a photo of a small {}.",
    "a tattoo of the {}.",
]
# Prompt templates used by VILDPromptExtractor (14 entries).
#
# Each entry is a ``str.format`` pattern whose single ``{}`` placeholder is
# filled with a noun in PredefinedPromptExtractor.forward.
#
# NOTE: several entries have no trailing period while the others do. The
# strings are runtime data passed to the text encoder, so they are kept
# verbatim; adding or removing punctuation would change the encoded text.
VILD_PROMPT = [
    "a photo of a {}.",
    "This is a photo of a {}",
    "There is a {} in the scene",
    "There is the {} in the scene",
    "a photo of a {} in the scene",
    "a photo of a small {}.",
    "a photo of a medium {}.",
    "a photo of a large {}.",
    "This is a photo of a small {}.",
    "This is a photo of a medium {}.",
    "This is a photo of a large {}.",
    "There is a small {} in the scene.",
    "There is a medium {} in the scene.",
    "There is a large {} in the scene.",
]
class PromptExtractor(nn.Module):
    """Base class for modules that map a list of nouns to text features.

    The base class only tracks whether :meth:`init_buffer` has been called;
    subclasses are expected to override :meth:`forward`.
    """

    def __init__(self):
        super().__init__()
        # Flipped to True by init_buffer().
        self._buffer_init = False

    def init_buffer(self, clip_model):
        """Record that buffer initialisation has happened.

        The base implementation does not read ``clip_model``; it only sets
        the ``_buffer_init`` flag.
        """
        self._buffer_init = True

    def forward(self, noun_list: List[str], clip_model: nn.Module):
        """Abstract hook: the base class always raises ``NotImplementedError``."""
        raise NotImplementedError
class PredefinedPromptExtractor(PromptExtractor):
    """Encode nouns with a fixed list of prompt templates and ensemble them.

    Every template is a ``str.format`` pattern with a single ``{}``
    placeholder that receives the noun.
    """

    def __init__(self, templates: List[str]):
        """Store the prompt templates.

        Args:
            templates: Format strings, each containing one ``{}`` placeholder.
        """
        super().__init__()
        self.templates = templates

    def forward(self, noun_list: List[str], clip_model: nn.Module):
        """Return the template-ensembled, L2-normalised text features.

        For each template, every noun is formatted into the template,
        tokenized with ``clip.tokenize`` and encoded with
        ``clip_model.encode_text``. The per-template features are
        L2-normalised along the last dimension, averaged across templates,
        and normalised once more.

        Args:
            noun_list: Nouns to substitute into the templates.
            clip_model: Model providing ``encode_text`` and a
                ``text_projection`` parameter; the latter is only used to
                find the device the tokens must be moved to.

        Returns:
            The ensembled features, normalised along the last dimension.
            Presumably shaped ``(len(noun_list), feature_dim)`` -- the exact
            shape is determined by ``clip_model.encode_text``.
        """
        # The target device does not change between templates.
        device = clip_model.text_projection.data.device

        per_template_features = []
        for template in self.templates:
            tokens = torch.cat(
                [clip.tokenize(template.format(noun)) for noun in noun_list]
            ).to(device)
            features = clip_model.encode_text(tokens)
            # Normalise out of place. An in-place ``/=`` would overwrite the
            # tensor that ``norm`` keeps for its backward pass (autograd then
            # fails when gradients are requested) and would also mutate the
            # tensor returned by ``encode_text``.
            features = features / features.norm(dim=-1, keepdim=True)
            per_template_features.append(features)

        # Ensemble by averaging over templates, then re-normalise because the
        # mean of unit vectors is generally not a unit vector.
        ensembled = torch.stack(per_template_features).mean(dim=0)
        return ensembled / ensembled.norm(dim=-1, keepdim=True)
class ImageNetPromptExtractor(PredefinedPromptExtractor):
    """Predefined-prompt extractor configured with ``IMAGENET_PROMPT``."""

    def __init__(self):
        super().__init__(templates=IMAGENET_PROMPT)
class VILDPromptExtractor(PredefinedPromptExtractor):
    """Predefined-prompt extractor configured with ``VILD_PROMPT``."""

    def __init__(self):
        super().__init__(templates=VILD_PROMPT)