from __future__ import annotations

from typing import Optional

from .helper import format_media_prompt
from ..typing import AsyncResult, Messages, MediaListType
from ..config import STATIC_URL
from .PollinationsAI import PollinationsAI

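class PollinationsImage(PollinationsAI):
    """Image-only variant of PollinationsAI: lists only the image models and
    is not active by default."""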
    label = "PollinationsImage"
    parent = PollinationsAI.__name__
    active_by_default = False
    default_model = "flux"
    default_vision_model = None
    default_image_model = default_model
    audio_models = {}

    @classmethod
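    def get_models(cls, **kwargs):
        """Refresh the parent provider's model lists and expose only its image models.

        Extra keyword arguments are accepted for interface compatibility and ignored.
        """
        PollinationsAI.get_models()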
        cls.image_models = PollinationsAI.image_models
        cls.models = cls.image_models
        return cls.models

    @classmethod
    def get_grouped_models(cls) -> list[dict]:
        """Return the image models as a single "Image Generation" group."""
        PollinationsAI.get_models()
        return [
            {"group": "Image Generation", "models": PollinationsAI.image_models},
        ]

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        media: Optional[MediaListType] = None,
        proxy: Optional[str] = None,
        referrer: str = STATIC_URL,
        api_key: Optional[str] = None,
        prompt: Optional[str] = None,
        aspect_ratio: Optional[str] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
        seed: Optional[int] = None,
        cache: bool = False,
        nologo: bool = True,
        private: bool = False,
        enhance: bool = False,
        safe: bool = False,
        transparent: bool = False,
        n: int = 1,
        **kwargs
    ) -> AsyncResult:
        """Build the image prompt from `messages` and `prompt`, then yield each generated chunk."""
        # Refresh the model list before generating
        cls.get_models()
        async for chunk in cls._generate_image(
            model=model,
            prompt=format_media_prompt(messages, prompt),
            media=media,
            proxy=proxy,
            aspect_ratio=aspect_ratio,
            width=width,
            height=height,
            seed=seed,
            cache=cache,
            nologo=nologo,
            private=private,
            enhance=enhance,
            safe=safe,
            transparent=transparent,
            n=n,
            referrer=referrer,
            api_key=api_key
        ):
            yield chunk