File size: 2,318 Bytes
a4b70d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from __future__ import annotations

from ..models import ModelUtils, ImageModel, VisionModel
from ..Provider import ProviderUtils
from ..providers.types import ProviderType

class ClientModels:
    """Expose model/provider lookups for a client.

    Holds an optional text provider (``provider``) and an optional media
    provider (``media_provider``). It offers helpers that list the models each
    of them reports, or fall back to the global model registry when no
    provider is configured.
    """

    def __init__(self, client, provider: ProviderType = None, media_provider: ProviderType = None):
        # ``client`` is consulted only for a fallback ``api_key`` attribute.
        self.client = client
        self.provider = provider
        self.media_provider = media_provider

    def get(self, name, default=None) -> ProviderType:
        """Resolve ``name`` to a provider.

        A name registered in ``ModelUtils.convert`` maps to that model's
        ``best_provider``. Otherwise a name registered in
        ``ProviderUtils.convert`` maps to the provider itself.

        Returns:
            The resolved provider, or ``default`` when ``name`` is in
            neither registry.
        """
        if name in ModelUtils.convert:
            return ModelUtils.convert[name].best_provider
        if name in ProviderUtils.convert:
            return ProviderUtils.convert[name]
        return default

    def _fetch_models(self, provider: ProviderType, api_key: str | None, kwargs: dict) -> list[str]:
        """Call ``provider.get_models`` with ``kwargs`` plus an API key, if any.

        An explicit ``api_key`` wins. Otherwise the client's ``api_key``
        attribute is used when present. The key is forwarded only when it is
        not None.
        """
        if api_key is None:
            # getattr: tolerate clients (or a None client) lacking an api_key attribute
            api_key = getattr(self.client, "api_key", None)
        # kwargs arrives as a plain dict so caller keys can never collide with
        # this helper's own parameter names.
        params = dict(kwargs)
        if api_key is not None:
            params["api_key"] = api_key
        return provider.get_models(**params)

    def get_all(self, api_key: str = None, **kwargs) -> list[str]:
        """Return the models reported by the text provider.

        Returns ``[]`` when no text provider is configured. Extra keyword
        arguments are forwarded to ``provider.get_models``.
        """
        if self.provider is None:
            return []
        return self._fetch_models(self.provider, api_key, kwargs)

    def get_vision(self, **kwargs) -> list[str]:
        """Return vision-capable model ids.

        Without a text provider, returns the ids of every registered model
        that is a ``VisionModel``. With a provider, calls ``get_all`` first and
        then returns the provider's ``vision_models`` attribute, or ``[]`` if
        the provider has none.
        """
        if self.provider is None:
            return [model_id for model_id, model in ModelUtils.convert.items() if isinstance(model, VisionModel)]
        # NOTE(review): the result of get_all is discarded; presumably the call
        # lets the provider populate ``vision_models`` — confirm per provider.
        self.get_all(**kwargs)
        if hasattr(self.provider, "vision_models"):
            return self.provider.vision_models
        return []

    def get_media(self, api_key: str = None, **kwargs) -> list[str]:
        """Return the models reported by the media provider.

        Returns ``[]`` when no media provider is configured. Extra keyword
        arguments are forwarded to ``media_provider.get_models``.
        """
        if self.media_provider is None:
            return []
        return self._fetch_models(self.media_provider, api_key, kwargs)

    def get_image(self, **kwargs) -> list[str]:
        """Return image-generation model ids.

        Without a media provider, returns the ids of every registered model
        that is an ``ImageModel``. With a media provider, calls ``get_media``
        first and then returns its ``image_models`` attribute, or ``[]`` if
        absent.
        """
        if self.media_provider is None:
            return [model_id for model_id, model in ModelUtils.convert.items() if isinstance(model, ImageModel)]
        # NOTE(review): called for its side effect on the provider — confirm.
        self.get_media(**kwargs)
        if hasattr(self.media_provider, "image_models"):
            return self.media_provider.image_models
        return []

    def get_video(self, **kwargs) -> list[str]:
        """Return video model ids reported by the media provider.

        Returns ``[]`` when no media provider is configured. Otherwise calls
        ``get_media`` first and then returns the media provider's
        ``video_models`` attribute, or ``[]`` if absent.
        """
        if self.media_provider is None:
            return []
        self.get_media(**kwargs)
        if hasattr(self.media_provider, "video_models"):
            return self.media_provider.video_models
        return []