File size: 6,230 Bytes
a4b70d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from __future__ import annotations

import os
import json

from ...typing import Messages, AsyncResult, MediaListType
from ...errors import MissingAuthError, ModelNotFoundError
from ...requests import StreamSession, FormData, raise_for_status
from ...image import get_width_height, to_bytes
from ...image.copy_images import save_response_media
from ..template import OpenaiTemplate
from ..helper import format_media_prompt

class Azure(OpenaiTemplate):
    """Azure provider built on top of ``OpenaiTemplate``.

    Per-model endpoints and API keys can be supplied through the
    ``AZURE_ROUTES`` and ``AZURE_API_KEYS`` environment variables (JSON text;
    the code uses the parsed values as mappings). Requests whose endpoint
    contains ``/images/`` are handled here as image generation/editing; all
    other requests are delegated to the parent class.
    """

    label = "Azure ☁️"
    url = "https://ai.azure.com"
    api_base = "https://host.g4f.dev/api/Azure"
    working = True
    needs_auth = True
    models_needs_auth = True
    active_by_default = True
    # Quoted in error messages so users know where to ask for a key or help.
    login_url = "https://discord.gg/qXA4Wf4Fsm"
    # Model name -> endpoint URL; filled from AZURE_ROUTES by get_models().
    routes: dict[str, str] = {}
    audio_models = ["gpt-4o-mini-audio-preview"]
    vision_models = ["gpt-4.1", "o4-mini", "model-router", "flux.1-kontext-pro"]
    image_models = ["flux-1.1-pro", "flux.1-kontext-pro"]
    model_aliases = {
        "flux-kontext": "flux.1-kontext-pro"
    }
    # Per-model request defaults merged into kwargs (explicit kwargs win).
    # Models listed here are always requested with stream=False.
    model_extra_body = {
        "gpt-4o-mini-audio-preview": {
            "audio": {
                "voice": "alloy",
                "format": "mp3"
            },
            "modalities": ["text", "audio"],
        }
    }
    # Model name -> API key, with an optional "default" entry as fallback;
    # filled from AZURE_API_KEYS by get_models().
    api_keys: dict[str, str] = {}
    # API key -> number of MissingAuthError failures seen for that key.
    failed: dict[str, int] = {}

    @classmethod
    def get_models(cls, api_key: str = None, **kwargs) -> list[str]:
        """Return the model names this provider can serve.

        Reloads ``cls.api_keys`` from ``AZURE_API_KEYS`` and ``cls.routes``
        from ``AZURE_ROUTES`` whenever those variables are set. When routes
        are configured, their keys are the model list; otherwise the lookup is
        delegated to the parent class.

        Args:
            api_key: Forwarded to the parent implementation when no routes
                are configured.
            **kwargs: Forwarded to the parent implementation.

        Returns:
            The list of model names.

        Raises:
            ValueError: If ``AZURE_API_KEYS`` or ``AZURE_ROUTES`` is set but
                does not contain valid JSON.
        """
        raw_keys = os.environ.get("AZURE_API_KEYS")
        if raw_keys:
            try:
                cls.api_keys = json.loads(raw_keys)
            except json.JSONDecodeError as e:
                # Chain the decode error so the parse position stays visible.
                raise ValueError("Invalid AZURE_API_KEYS environment variable") from e
        raw_routes = os.environ.get("AZURE_ROUTES")
        if raw_routes:
            try:
                cls.routes = json.loads(raw_routes)
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid AZURE_ROUTES environment variable format: {raw_routes}") from e
        if cls.routes:
            # `live` is not defined in this class (presumably inherited from
            # the provider base — confirm); bump it once when both routes and
            # keys are configured.
            if cls.live == 0 and cls.api_keys:
                cls.live += 1
            return list(cls.routes.keys())
        return super().get_models(api_key=api_key, **kwargs)

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        stream: bool = True,
        media: MediaListType = None,
        api_key: str = None,
        api_endpoint: str = None,
        **kwargs
    ) -> AsyncResult:
        """Run a request against the endpoint resolved for ``model``.

        The endpoint comes from ``api_endpoint``, then ``cls.routes``, then
        the ``AZURE_API_ENDPOINT`` environment variable. If the endpoint
        contains ``/images/`` the request is an image generation (JSON body)
        or, when ``media`` is given, an image edit (multipart form); the
        result is passed through ``save_response_media``. Any other request is
        delegated to the parent class.

        Args:
            model: Model name or alias; when empty, ``AZURE_DEFAULT_MODEL`` or
                ``cls.default_model`` is used.
            messages: Conversation messages; also the source of the image
                prompt.
            stream: Whether to stream; forced to ``False`` for models listed
                in ``model_extra_body``.
            media: Optional ``(data, name)`` entries. The caller's sequence is
                not modified.
            api_key: API key; replaced by the configured key whenever
                ``cls.api_keys`` is non-empty.
            api_endpoint: Explicit endpoint URL that bypasses route lookup.
            **kwargs: Extra options. The image path reads ``prompt``,
                ``aspect_ratio``, ``width``, ``height``, ``output_format``
                and ``proxy``; the chat path forwards everything to the
                parent.

        Yields:
            Chunks from ``save_response_media`` (image path) or from the
            parent generator (chat path).

        Raises:
            ModelNotFoundError: Routes are configured but none matches
                ``model``.
            ValueError: API keys are configured but none applies to ``model``
                (or an environment variable read by ``get_models`` is
                invalid).
            MissingAuthError: The key already failed three times, or the
                parent reports an authentication failure.
        """
        if not model:
            model = os.environ.get("AZURE_DEFAULT_MODEL", cls.default_model)
        if model in cls.model_aliases:
            model = cls.model_aliases[model]
        if not api_endpoint:
            if not cls.routes:
                # Lazily load routes/keys from the environment.
                cls.get_models()
            api_endpoint = cls.routes.get(model)
            if cls.routes and not api_endpoint:
                raise ModelNotFoundError(f"No API endpoint found for model: {model}")
        if not api_endpoint:
            api_endpoint = os.environ.get("AZURE_API_ENDPOINT")
        if cls.api_keys:
            # NOTE: configured keys take precedence over the `api_key` argument.
            api_key = cls.api_keys.get(model, cls.api_keys.get("default"))
            if not api_key:
                raise ValueError(f"API key is required for Azure provider. Ask for API key in the {cls.login_url} Discord server.")
        if api_endpoint and "/images/" in api_endpoint:
            prompt = format_media_prompt(messages, kwargs.get("prompt"))
            width, height = get_width_height(kwargs.get("aspect_ratio", "1:1"), kwargs.get("width"), kwargs.get("height"))
            output_format = kwargs.get("output_format", "png")
            form = None
            json_body = None
            if media:
                # Image edit: multipart upload; the result is labelled as PNG.
                form = FormData()
                form.add_field("prompt", prompt)
                form.add_field("width", str(width))
                form.add_field("height", str(height))
                output_format = "png"
                # Convert into local values instead of writing back into the
                # caller's `media` sequence.
                uploads = []
                for entry in media:
                    source, image_name = entry[0], entry[1]
                    if image_name is None and isinstance(source, str):
                        # Derive a filename from the string source.
                        image_name = os.path.basename(source)
                    uploads.append((to_bytes(source), image_name))
                for image, image_name in uploads:
                    form.add_field("image", image, filename=image_name)
            else:
                # No input images: use the generations endpoint instead.
                api_endpoint = api_endpoint.replace("/edits", "/generations")
                json_body = {
                    "prompt": prompt,
                    "n": 1,
                    "width": width,
                    "height": height,
                    "output_format": output_format,
                }
            async with StreamSession(proxy=kwargs.get("proxy"), headers={
                "Authorization": f"Bearer {api_key}",
                "x-ms-model-mesh-model-name": model,
            }) as session:
                async with session.post(api_endpoint, data=form, json=json_body) as response:
                    # Parse first so the error check can inspect the payload.
                    result = await response.json()
                    await raise_for_status(response, result)
                    async for chunk in save_response_media(
                        result["data"][0]["b64_json"],
                        prompt,
                        content_type=f"image/{output_format.replace('jpg', 'jpeg')}"
                    ):
                        yield chunk
            return
        if model in cls.model_extra_body:
            for key, value in cls.model_extra_body[model].items():
                kwargs.setdefault(key, value)
            stream = False
        if stream:
            kwargs.setdefault("stream_options", {"include_usage": True})
        # Stop using a key once it has failed authentication three times.
        if cls.failed.get(api_key, 0) >= 3:
            raise MissingAuthError("API key has failed too many times.")
        try:
            async for chunk in super().create_async_generator(
                model=model,
                messages=messages,
                stream=stream,
                media=media,
                api_key=api_key,
                api_endpoint=api_endpoint,
                **kwargs
            ):
                yield chunk
        except MissingAuthError as e:
            cls.failed[api_key] = cls.failed.get(api_key, 0) + 1
            raise MissingAuthError(f"{e}. Ask for help in the {cls.login_url} Discord server.") from e