|
|
from __future__ import annotations |
|
|
|
|
|
from typing import Optional |
|
|
from functools import partial |
|
|
from dataclasses import dataclass, field |
|
|
|
|
|
from pydantic_ai.models import Model, KnownModelName, infer_model |
|
|
from pydantic_ai.models.openai import OpenAIModel, OpenAISystemPromptRole |
|
|
|
|
|
# NOTE(review): module-wide monkey-patch — replaces the OpenAI NOT_GIVEN
# sentinel inside pydantic_ai.models.openai with None, presumably because the
# g4f AsyncClient does not understand the sentinel object in request kwargs.
# Confirm against AsyncClient's parameter handling.
import pydantic_ai.models.openai

pydantic_ai.models.openai.NOT_GIVEN = None
|
|
|
|
|
from ..client import AsyncClient |
|
|
|
|
|
@dataclass(init=False)
class AIModel(OpenAIModel):
    """A pydantic-ai ``Model`` backed by the G4F :class:`AsyncClient`.

    Subclasses ``OpenAIModel`` so all request/response handling is inherited;
    only the underlying client and the model naming are replaced.
    """

    # G4F client that performs the actual completion calls.
    client: AsyncClient = field(repr=False)
    # Role used for the system prompt message; None defers to the parent
    # class's default ('system').
    system_prompt_role: OpenAISystemPromptRole | None = field(default=None)

    # Model name as passed to __init__, e.g. 'gpt-4o-mini'.
    _model_name: str = field(repr=False)
    # G4F provider name, or None when no provider was requested.
    _provider: Optional[str] = field(repr=False)
    # Observability label for the upstream provider (defaults to 'openai').
    _system: Optional[str] = field(repr=False)

    def __init__(
        self,
        model_name: str,
        provider: str | None = None,
        *,
        system_prompt_role: OpenAISystemPromptRole | None = None,
        system: str | None = 'openai',
        **kwargs
    ):
        """Initialize an AI model.

        Args:
            model_name: Name of the G4F model to use.
            provider: Optional G4F provider name, forwarded to ``AsyncClient``.
                ``None`` presumably lets g4f choose a provider — confirm
                against AsyncClient.
            system_prompt_role: The role to use for the system prompt message.
                If not provided, defaults to `'system'`.
            system: Label of the model provider used, defaults to `'openai'`.
                This is for observability purposes.
            **kwargs: Extra keyword arguments forwarded to ``AsyncClient``
                (e.g. ``api_key``).
        """
        self._model_name = model_name
        self._provider = provider
        self.client = AsyncClient(provider=provider, **kwargs)
        self.system_prompt_role = system_prompt_role
        self._system = system

    def name(self) -> str:
        """Return a unique identifier: ``g4f:<provider>:<model>`` or ``g4f:<model>``."""
        if self._provider:
            return f'g4f:{self._provider}:{self._model_name}'
        return f'g4f:{self._model_name}'
|
|
|
|
|
def new_infer_model(model: Model | KnownModelName, api_key: str | None = None) -> Model:
    """Infer a model, adding support for ``g4f:``-prefixed model names.

    Names of the form ``g4f:<model>`` or ``g4f:<provider>:<model>`` resolve to
    :class:`AIModel`; everything else is delegated to pydantic-ai's stock
    :func:`infer_model`.

    Args:
        model: A ``Model`` instance (returned unchanged) or a model-name string.
        api_key: Optional API key forwarded to the G4F ``AsyncClient``.

    Returns:
        A ``Model`` instance.
    """
    if isinstance(model, Model):
        return model
    if model.startswith("g4f:"):
        spec = model[4:]  # strip the 'g4f:' prefix
        if ":" in spec:
            provider, model_name = spec.split(":", 1)
            return AIModel(model_name, provider=provider, api_key=api_key)
        # Fix: the original dropped api_key for unqualified 'g4f:<model>'
        # names; forward it consistently with the provider-qualified branch.
        return AIModel(spec, api_key=api_key)
    return infer_model(model)
|
|
|
|
|
def patch_infer_model(api_key: str | None = None):
    """Monkey-patch ``pydantic_ai.models`` so that ``infer_model`` understands
    ``g4f:`` model names and ``AIModel`` is reachable from that module.

    Args:
        api_key: Optional API key bound into the patched ``infer_model``.
    """
    import pydantic_ai.models as models_module

    patched = partial(new_infer_model, api_key=api_key)
    models_module.infer_model = patched
    models_module.AIModel = AIModel