File size: 6,711 Bytes
a4b70d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from __future__ import annotations

import os
import uuid

from ...typing import AsyncResult, Messages, MediaListType
from ...providers.response import ImageResponse, JsonConversation, Reasoning
from ...requests import StreamSession, FormData, sse_stream
from ...tools.media import merge_media
from ...image import to_bytes, is_accepted_format
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_media_prompt
from .DeepseekAI_JanusPro7b import get_zerogpu_token
from .raise_for_status import raise_for_status

class BlackForestLabs_Flux1KontextDev(AsyncGeneratorProvider, ProviderModelMixin):
    """Provider for the FLUX.1-Kontext-Dev Hugging Face Space.

    Uploads the supplied image(s) to the space, enqueues a job through the
    space's Gradio queue endpoints with the last uploaded image and the
    prompt, then follows the server-sent event stream and yields progress
    updates followed by the resulting image.
    """

    label = "BlackForestLabs Flux-1-Kontext-Dev"
    url = "https://black-forest-labs-flux-1-kontext-dev.hf.space"
    space = "black-forest-labs/FLUX.1-Kontext-Dev"
    referer = f"{url}/?__theme=system"
    working = True

    default_model = "flux-kontext-dev"
    default_image_model = default_model
    image_models = [default_model]
    models = image_models

    @classmethod
    def run(cls, method: str, session: StreamSession, conversation: JsonConversation, data: list = None):
        """Build the queue request for the given method without awaiting it.

        Args:
            method: ``"post"`` to enqueue a job; any other value opens the
                event stream for the conversation's session hash.
            session: Session used to issue the request.
            conversation: Holds ``zerogpu_token``, ``zerogpu_uuid`` and
                ``session_hash`` for this job.
            data: Positional inputs of the job; only used for ``"post"``.

        Returns:
            Whatever ``session.post`` / ``session.get`` returns; the caller
            in this class enters it with ``async with``.
        """
        is_post = method == "post"
        headers = {
            # The join call answers with JSON, the data call with an event stream.
            "accept": "application/json" if is_post else "text/event-stream",
            "content-type": "application/json",
            "x-zerogpu-token": conversation.zerogpu_token,
            "x-zerogpu-uuid": conversation.zerogpu_uuid,
            "referer": cls.referer,
        }
        # Drop headers whose value is None (e.g. the token has not been set).
        headers = {name: value for name, value in headers.items() if value is not None}

        if is_post:
            # NOTE(review): fn_index and trigger_id are hard-coded; presumably
            # taken from a captured browser request to this space - re-check
            # them if the space's UI changes.
            payload = {
                "data": data,
                "event_data": None,
                "fn_index": 2,
                "trigger_id": 7,
                "session_hash": conversation.session_hash,
            }
            return session.post(
                f"{cls.url}/gradio_api/queue/join?__theme=system",
                headers=headers,
                json=payload,
            )

        return session.get(
            f"{cls.url}/gradio_api/queue/data?session_hash={conversation.session_hash}",
            headers=headers,
        )

    @classmethod
    async def _upload_media(cls, session: StreamSession, conversation: JsonConversation, media: list) -> list:
        """Upload ``(image, name)`` pairs and return one file descriptor per uploaded file."""
        prepared = []
        for image, image_name in media:
            # A string image without an explicit name gets its basename as name.
            if image_name is None and isinstance(image, str):
                image_name = os.path.basename(image)
            prepared.append((to_bytes(image), image_name))

        form = FormData()
        for image_bytes, image_name in prepared:
            form.add_field("files", image_bytes, filename=image_name)

        async with session.post(
            f"{cls.url}/gradio_api/upload",
            params={"upload_id": conversation.session_hash},
            data=form,
        ) as response:
            await raise_for_status(response)
            uploaded_paths = await response.json()

        # The response is paired with the uploaded files by position.
        return [
            {
                "path": uploaded_path,
                "url": f"{cls.url}/gradio_api/file={uploaded_path}",
                "orig_name": image_name,
                "size": len(image_bytes),
                "mime_type": is_accepted_format(image_bytes),
                "meta": {
                    "_type": "gradio.FileData"
                },
            }
            for uploaded_path, (image_bytes, image_name) in zip(uploaded_paths, prepared)
        ]

    @staticmethod
    def _extract_image_url(chunk: dict) -> str:
        """Return the image URL of a ``process_completed`` event.

        Raises:
            RuntimeError: If the event carries an error or no image URL.
        """
        output = chunk.get("output") or {}
        error = output.get("error")
        if error:
            raise RuntimeError(f"Image generation failed: {error}")
        results = output.get("data") or []
        first_result = results[0] if results else None
        image_url = first_result.get("url") if isinstance(first_result, dict) else None
        if not image_url:
            raise RuntimeError("Image generation completed without returning an image URL.")
        return image_url

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        prompt: str = None,
        media: MediaListType = None,
        proxy: str = None,
        guidance_scale: float = 2.5,
        num_inference_steps: int = 28,
        seed: int = 0,
        randomize_seed: bool = True,
        cookies: dict = None,
        api_key: str = None,
        zerogpu_uuid: str = None,
        **kwargs
    ) -> AsyncResult:
        """Run one job on the space and yield its progress and result.

        Args:
            model: Accepted for interface compatibility; not read here.
            messages: Conversation used to collect media and build the prompt.
            prompt: Optional explicit prompt passed to ``format_media_prompt``.
            media: Input images; merged with media found in ``messages``.
            proxy: Proxy passed to the HTTP session.
            guidance_scale: Forwarded to the space as a job input.
            num_inference_steps: Forwarded to the space as a job input.
            seed: Forwarded to the space as a job input.
            randomize_seed: Forwarded to the space as a job input.
            cookies: Passed to ``get_zerogpu_token`` when a token is fetched.
            api_key: Used as the ZeroGPU token; fetched when ``None``.
            zerogpu_uuid: ZeroGPU uuid; a random one is generated if omitted.

        Yields:
            ``Reasoning`` progress updates and one ``ImageResponse``.

        Raises:
            ValueError: If there is no media to send.
            RuntimeError: If the job cannot be enqueued, fails, or completes
                without an image URL.
        """
        # Per-request container for the tokens and the queue session hash.
        conversation = JsonConversation(
            zerogpu_token=api_key,
            zerogpu_uuid=zerogpu_uuid or uuid.uuid4().hex,
            session_hash=uuid.uuid4().hex,
        )
        async with StreamSession(impersonate="chrome", proxy=proxy) as session:
            media = list(merge_media(media, messages))
            uploaded = await cls._upload_media(session, conversation, media) if media else []
            if not uploaded:
                raise ValueError("No media files provided for image generation.")

            prompt = format_media_prompt(messages, prompt)

            # Only one image is sent: the last uploaded one.
            data = [
                uploaded[-1],
                prompt,
                seed,
                randomize_seed,
                guidance_scale,
                num_inference_steps,
            ]

            # Fetch a token (and its matching uuid) when none was supplied.
            if conversation.zerogpu_token is None:
                conversation.zerogpu_uuid, conversation.zerogpu_token = await get_zerogpu_token(
                    cls.space, session, conversation, cookies
                )

            # Enqueue the job.
            async with cls.run("post", session, conversation, data) as response:
                await raise_for_status(response)
                result_json = await response.json()
                # Explicit check instead of `assert`, which is removed under -O.
                if not result_json.get("event_id"):
                    raise RuntimeError(f"Queue join did not return an event_id: {result_json}")

            # Follow the event stream until the job completes.
            async with cls.run("get", session, conversation) as event_response:
                await raise_for_status(event_response)
                async for chunk in sse_stream(event_response):
                    msg = chunk.get("msg")
                    if msg == "process_starts":
                        yield Reasoning(label="Processing started")
                    elif msg == "progress":
                        progress_data = chunk.get("progress_data", [])
                        progress_data = progress_data[0] if progress_data else {}
                        yield Reasoning(label="Processing image", status=f"{progress_data.get('index', 0)}/{progress_data.get('length', 0)}")
                    elif msg == "process_completed":
                        yield ImageResponse(cls._extract_image_url(chunk), prompt)
                        yield Reasoning(label="Completed", status="")
                        break