# Source snapshot: commit a4b70d9 (3,541 bytes).
from __future__ import annotations

import os
import logging
from typing import Any, Union, Optional, Coroutine

from . import debug, version

from .models import Model
from .client import Client, AsyncClient
from .typing import Messages, CreateResult, AsyncResult, ImageType
from .cookies import get_cookies, set_cookies
from .providers.types import ProviderType
from .providers.helper import concat_chunks, async_concat_chunks
from .client.service import get_model_and_provider

# Package-wide "g4f" logger: only records at ERROR or above pass the logger's
# level check, and they are written by a StreamHandler (stderr by default)
# using logging's basic "LEVEL:name:message" format.
logger = logging.getLogger("g4f")
logger.setLevel(logging.ERROR)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(fmt=logging.BASIC_FORMAT))
logger.addHandler(handler)


class ChatCompletion:
    """Entry points for creating chat completions through a g4f provider.

    ``create`` calls the resolved provider's ``create_function`` and
    ``create_async`` its ``async_create_function``; both first normalise
    their arguments through ``_prepare_request``.
    """

    @staticmethod
    def _prepare_request(model: Union[Model, str],
                         messages: Messages,
                         provider: Union[ProviderType, str, None],
                         stream: bool,
                         image: ImageType,
                         image_name: Optional[str],
                         ignore_working: bool,
                         ignore_stream: bool,
                         **kwargs):
        """Shared pre-processing for sync/async create methods.

        Args:
            model: Model object or model name, resolved by
                ``get_model_and_provider``.
            messages: Conversation to send. Not used here; accepted so both
                public methods can forward the same argument list.
            provider: Provider type, provider name, or ``None``; resolved by
                ``get_model_and_provider``.
            stream: Whether the caller asked for a streamed result.
            image: Optional single image; stored as ``kwargs["media"]``.
            image_name: Name paired with ``image`` in the media entry.
            ignore_working: Forwarded to ``get_model_and_provider``.
            ignore_stream: Forwarded to ``get_model_and_provider`` and, when
                true, also passed to the provider as ``ignore_stream=True``.
            **kwargs: Extra provider options. An ``images`` entry is moved to
                ``media`` when no ``image`` was given.

        Returns:
            A ``(model, provider, kwargs)`` tuple: the model and provider as
            resolved by ``get_model_and_provider``, plus the normalised
            provider keyword arguments.
        """
        if image is not None:
            # An explicit single image wins; any ``images`` kwarg is then
            # left in kwargs untouched.
            kwargs["media"] = [(image, image_name)]
        elif "images" in kwargs:
            kwargs["media"] = kwargs.pop("images")

        model, provider = get_model_and_provider(
            model, provider, stream,
            ignore_working,
            ignore_stream,
            has_images="media" in kwargs,
        )

        # Only fall back to G4F_PROXY when the caller did not pass ``proxy``
        # at all; an explicit ``proxy=None`` is respected as-is.
        if "proxy" not in kwargs:
            proxy = os.environ.get("G4F_PROXY")
            if proxy:
                kwargs["proxy"] = proxy
        if ignore_stream:
            kwargs["ignore_stream"] = True

        return model, provider, kwargs

    @staticmethod
    def create(model: Union[Model, str],
               messages: Messages,
               provider: Union[ProviderType, str, None] = None,
               stream: bool = False,
               image: ImageType = None,
               image_name: Optional[str] = None,
               ignore_working: bool = False,
               ignore_stream: bool = False,
               **kwargs) -> Union[CreateResult, str]:
        """Create a chat completion synchronously.

        Returns the provider's raw result when ``stream`` or ``ignore_stream``
        is true; otherwise the result chunks joined by ``concat_chunks``.
        """
        model, provider, kwargs = ChatCompletion._prepare_request(
            model, messages, provider, stream, image, image_name,
            ignore_working, ignore_stream, **kwargs
        )
        result = provider.create_function(model, messages, stream=stream, **kwargs)
        return result if stream or ignore_stream else concat_chunks(result)

    @staticmethod
    def create_async(model: Union[Model, str],
                     messages: Messages,
                     provider: Union[ProviderType, str, None] = None,
                     stream: bool = False,
                     image: ImageType = None,
                     image_name: Optional[str] = None,
                     ignore_working: bool = False,
                     ignore_stream: bool = False,
                     **kwargs) -> Union[AsyncResult, Coroutine[Any, Any, str]]:
        """Create a chat completion via the provider's async entry point.

        This method itself is not ``async``: it returns the provider's
        ``async_create_function`` result unchanged when ``stream`` or
        ``ignore_stream`` is true, or when that result has no ``__aiter__``.
        Otherwise it returns ``async_concat_chunks(result)``, presumably an
        awaitable resolving to the joined text (per the return annotation).
        """
        model, provider, kwargs = ChatCompletion._prepare_request(
            model, messages, provider, stream, image, image_name,
            ignore_working, ignore_stream, **kwargs
        )
        result = provider.async_create_function(model, messages, stream=stream, **kwargs)
        if not stream and not ignore_stream and hasattr(result, "__aiter__"):
            result = async_concat_chunks(result)
        return result