Spaces:
Runtime error
Runtime error
| from functools import partial | |
| from typing import Iterator, Union, List, Mapping, Literal | |
| from PIL import Image | |
| from imgutils.tagging import get_deepdanbooru_tags, get_wd14_tags, get_mldanbooru_tags | |
| from .base import ProcessAction, BaseAction | |
| from ..model import ImageItem | |
def _deepdanbooru_tagging(image: Image.Image, use_real_name: bool = False,
                          general_threshold: float = 0.5, character_threshold: float = 0.5, **kwargs):
    """Tag *image* with DeepDanbooru and return general and character tags as one dict.

    :param image: Image to tag.
    :param use_real_name: Forwarded unchanged to ``get_deepdanbooru_tags``.
    :param general_threshold: Threshold for general tags, forwarded unchanged.
    :param character_threshold: Threshold for character tags, forwarded unchanged.
    :param kwargs: Accepted and ignored, so that options shared across different
        tagging backends can be passed without raising ``TypeError``.
    :return: A single dict containing both groups of tags; if a name appears in
        both groups, the character entry wins.
    """
    del kwargs  # deliberately unused, see docstring
    _ignored, general_tags, character_tags = get_deepdanbooru_tags(
        image, use_real_name, general_threshold, character_threshold,
    )
    merged = dict(general_tags)
    merged.update(character_tags)
    return merged
def _wd14_tagging(image: Image.Image, model_name: str,
                  general_threshold: float = 0.35, character_threshold: float = 0.85, **kwargs):
    """Tag *image* with a WD14 model and return general and character tags as one dict.

    :param image: Image to tag.
    :param model_name: WD14 backbone name, forwarded to ``get_wd14_tags``.
    :param general_threshold: Threshold for general tags, forwarded unchanged.
    :param character_threshold: Threshold for character tags, forwarded unchanged.
    :param kwargs: Accepted and ignored, so that options shared across different
        tagging backends can be passed without raising ``TypeError``.
    :return: A single dict containing both groups of tags; if a name appears in
        both groups, the character entry wins.
    """
    del kwargs  # deliberately unused, see docstring
    _ignored, general_tags, character_tags = get_wd14_tags(
        image, model_name, general_threshold, character_threshold,
    )
    merged = dict(general_tags)
    merged.update(character_tags)
    return merged
def _mldanbooru_tagging(image: Image.Image, use_real_name: bool = False, general_threshold: float = 0.7, **kwargs):
    """Tag *image* with ML-Danbooru and return the tagger's result unchanged.

    :param image: Image to tag.
    :param use_real_name: Forwarded unchanged to ``get_mldanbooru_tags``.
    :param general_threshold: Forwarded as the tagger's (single) threshold argument.
    :param kwargs: Accepted and ignored (e.g. a ``character_threshold`` meant for
        another backend), so shared options do not raise ``TypeError``.
    :return: Whatever ``get_mldanbooru_tags`` returns.
    """
    del kwargs  # deliberately unused, see docstring
    return get_mldanbooru_tags(image, use_real_name, general_threshold)
# Names of the supported tagging backends, for type annotations.
# NOTE: must be kept in sync with the keys of ``_TAGGING_METHODS`` below.
TaggingMethodTyping = Literal[
    'deepdanbooru',
    'wd14_vit',
    'wd14_convnext',
    'wd14_convnextv2',
    'wd14_swinv2',
    'mldanbooru',
]

# Dispatch table from backend name to tagging callable. Each callable is invoked
# as ``fn(image=..., **kwargs)``. The four WD14 entries share one helper and
# differ only in the ``model_name`` bound here.
_TAGGING_METHODS = {
    'deepdanbooru': _deepdanbooru_tagging,
    'wd14_vit': partial(_wd14_tagging, model_name='ViT'),
    'wd14_convnext': partial(_wd14_tagging, model_name='ConvNext'),
    'wd14_convnextv2': partial(_wd14_tagging, model_name='ConvNextV2'),
    'wd14_swinv2': partial(_wd14_tagging, model_name='SwinV2'),
    'mldanbooru': _mldanbooru_tagging,
}
class TaggingAction(ProcessAction):
    """Store tagger predictions in ``item.meta['tags']``.

    :param method: Backend name; looked up in ``_TAGGING_METHODS``, so an
        unknown name raises ``KeyError``.
    :param force: If true, re-tag items whose meta already has ``'tags'``;
        otherwise such items are returned as-is.
    :param kwargs: Extra keyword arguments passed to the tagging callable on
        every call (e.g. thresholds).
    """

    def __init__(self, method: TaggingMethodTyping = 'wd14_convnextv2', force: bool = False, **kwargs):
        self.method = _TAGGING_METHODS[method]
        self.force = force
        self.kwargs = kwargs

    def process(self, item: ImageItem) -> ImageItem:
        """Return *item* with tags attached, reusing existing tags unless forced."""
        already_tagged = 'tags' in item.meta
        if already_tagged and not self.force:
            return item

        predicted = self.method(image=item.image, **self.kwargs)
        # Build a new meta dict so the input item's meta is not mutated.
        new_meta = {**item.meta, 'tags': predicted}
        return ImageItem(item.image, new_meta)
class TagFilterAction(BaseAction):
    """Keep only images whose tags reach the required minimum scores.

    Each item is first passed through a :class:`TaggingAction` with
    ``force=False``, so existing ``meta['tags']`` are reused. The item is then
    yielded only if every required tag scores at least its minimum.

    :param tags: Required tags. Either a collection of tag names (list, tuple
        or set), each of which must score at least ``1e-6``, or a mapping of
        tag name to minimum score.
    :param method: Tagging backend used for items that have no tags yet.
    :param kwargs: Forwarded to :class:`TaggingAction` (e.g. thresholds).
    :raises TypeError: If *tags* is neither a mapping nor a list/tuple/set.
    """

    def __init__(self, tags: Union[List[str], Mapping[str, float]],
                 method: TaggingMethodTyping = 'wd14_convnextv2', **kwargs):
        # Check for a mapping first: any Mapping, not only dict, matches the annotation.
        if isinstance(tags, Mapping):
            self.tags = dict(tags)
        elif isinstance(tags, (list, tuple, set, frozenset)):
            # A tiny positive minimum: the tag only needs to be reported at all.
            self.tags = {tag: 1e-6 for tag in tags}
        else:
            raise TypeError(f'Unknown type of tags - {tags!r}.')
        self.tagger = TaggingAction(method, force=False, **kwargs)

    def iter(self, item: ImageItem) -> Iterator[ImageItem]:
        """Yield *item* (tagged) if it satisfies every required tag, else nothing."""
        item = self.tagger(item)
        tags = item.meta['tags']
        # The taggers are called with score thresholds and may omit tags below
        # them, so a required tag can be absent. Treat it as score 0.0 (the item
        # is rejected) instead of letting ``tags[tag]`` raise KeyError.
        rejected = any(tags.get(tag, 0.0) < min_score for tag, min_score in self.tags.items())
        if not rejected:
            yield item

    def reset(self):
        """Reset the wrapped tagger."""
        self.tagger.reset()