File size: 3,039 Bytes
4b5bc79
5b715cc
 
 
 
 
4b5bc79
 
5b715cc
4b5bc79
7038b94
5b715cc
 
 
 
 
4b5bc79
5b715cc
 
 
4b5bc79
5b715cc
 
 
 
 
 
 
 
7038b94
5b715cc
 
 
7038b94
5b715cc
 
7038b94
5b715cc
7038b94
5b715cc
 
 
 
7038b94
5b715cc
 
 
 
 
 
4b5bc79
5b715cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b5bc79
5b715cc
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
from typing import Any, Callable

from smolagents import LiteLLMModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from functools import lru_cache
import time
import re
from litellm import RateLimitError


class LocalTransformersModel:
    """Thin wrapper exposing a Hugging Face text-generation pipeline as a callable."""

    def __init__(self, model_id: str, **kwargs):
        """Load the tokenizer and causal-LM weights for *model_id*.

        Extra keyword arguments are forwarded to ``AutoModelForCausalLM.from_pretrained``
        (e.g. device placement or dtype options).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, prompt: str, **kwargs):
        """Run generation on *prompt* and return the first candidate's generated text."""
        results = self.pipeline(prompt, **kwargs)
        return results[0]["generated_text"]

class WrapperLiteLLMModel(LiteLLMModel):
    """LiteLLMModel that retries on RateLimitError, honoring the server-suggested delay."""

    def __call__(self, messages, **kwargs):
        """Delegate to ``LiteLLMModel.__call__``, retrying on rate limits.

        On each RateLimitError the provider-suggested ``retryDelay`` is parsed
        from the exception text (falling back to 50s) and slept before the
        next attempt. After the last failed attempt the original exception is
        re-raised.

        Raises:
            RateLimitError: the last rate-limit error, if all retries fail.
        """
        max_retry = 5
        last_error = None
        for attempt in range(max_retry):
            try:
                return super().__call__(messages, **kwargs)
            except RateLimitError as e:
                last_error = e
                print(f"RateLimitError (attempt {attempt+1}/{max_retry})")

                # No point sleeping after the final attempt — fail fast instead.
                if attempt == max_retry - 1:
                    break

                # Try to extract retry time from the exception string
                match = re.search(r'"retryDelay": ?"(\d+)s"', str(e))
                retry_seconds = int(match.group(1)) if match else 50

                print(f"Sleeping for {retry_seconds} seconds before retrying...")
                time.sleep(retry_seconds)

        # Re-raise the original exception: litellm.RateLimitError cannot be
        # constructed from a bare message (it requires model/llm_provider),
        # and the original carries the provider's full context anyway.
        raise last_error

@lru_cache(maxsize=1)
def get_lite_llm_model(model_id: str, **kwargs) -> WrapperLiteLLMModel:
    """
    Returns a rate-limit-aware LiteLLM model instance.

    Note: the ``lru_cache(maxsize=1)`` means repeated calls with the same
    arguments reuse a single instance; calling with different arguments
    evicts the previously cached model.

    Args:
        model_id (str): The model identifier.
        **kwargs: Additional keyword arguments for the model.

    Returns:
        WrapperLiteLLMModel: LiteLLM model instance that retries on rate limits.
    """
    # Use the unified environment variable name GEMINI_API_KEY
    return WrapperLiteLLMModel(model_id=model_id, api_key=os.getenv("GEMINI_API_KEY"), **kwargs)


@lru_cache(maxsize=1)
def get_local_model(model_id: str, **kwargs) -> LocalTransformersModel:
    """
    Returns a local Transformers model.

    Note: the ``lru_cache(maxsize=1)`` means repeated calls with the same
    arguments reuse a single instance; calling with different arguments
    evicts the previously cached model.

    Args:
        model_id (str): The model identifier.
        **kwargs: Additional keyword arguments for the model.

    Returns:
        LocalTransformersModel: locally loaded Transformers model instance.
    """
    return LocalTransformersModel(model_id=model_id, **kwargs)


def get_model(model_type: str, model_id: str, **kwargs) -> Any:
    """
    Returns a model instance based on the specified type.

    Args:
        model_type (str): The type of the model (e.g., 'HfApiModel').
        model_id (str): The model identifier.
        **kwargs: Additional keyword arguments for the model.

    Returns:
        Any: Model instance of the specified type.
    """
    # Dispatch table mapping type names to their (cached) factory functions.
    factories: dict[str, Callable[..., Any]] = {
        "LiteLLMModel": get_lite_llm_model,
        "LocalTransformersModel": get_local_model,
    }

    factory = factories.get(model_type)
    if factory is None:
        raise ValueError(f"Unknown model type: {model_type}")

    return factory(model_id, **kwargs)