|
|
import os |
|
|
from typing import Any, Callable |
|
|
|
|
|
from smolagents import LiteLLMModel |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
|
|
from functools import lru_cache |
|
|
import time |
|
|
import re |
|
|
from litellm import RateLimitError |
|
|
|
|
|
|
|
|
class LocalTransformersModel:
    """Callable wrapper around a local ``transformers`` text-generation pipeline.

    The tokenizer, the causal language model and the pipeline built from them
    are kept on the instance as ``tokenizer``, ``model`` and ``pipeline``.
    """

    def __init__(self, model_id: str, **kwargs):
        """Load the tokenizer and model for ``model_id`` and build the pipeline.

        Args:
            model_id (str): Identifier passed to both ``from_pretrained`` calls.
            **kwargs: Forwarded to ``AutoModelForCausalLM.from_pretrained`` only;
                the tokenizer is loaded with ``model_id`` alone.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.tokenizer = tokenizer

        causal_lm = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
        self.model = causal_lm

        self.pipeline = pipeline(
            "text-generation",
            model=causal_lm,
            tokenizer=tokenizer,
        )

    def __call__(self, prompt: str, **kwargs):
        """Run the pipeline on ``prompt`` and return the first generated text.

        Args:
            prompt (str): Input handed to the pipeline.
            **kwargs: Forwarded unchanged to the pipeline call.

        Returns:
            The ``"generated_text"`` entry of the first pipeline result.
        """
        generations = self.pipeline(prompt, **kwargs)
        first_result = generations[0]
        return first_result["generated_text"]
|
|
|
|
|
class WrapperLiteLLMModel(LiteLLMModel):
    """``LiteLLMModel`` that retries a call when it fails with ``RateLimitError``."""

    def __call__(self, messages, **kwargs):
        """Call the parent model, sleeping and retrying on rate-limit errors.

        Up to five attempts are made. Between attempts the wrapper sleeps for
        the number of seconds advertised in the error text (a
        ``"retryDelay": "<N>s"`` fragment), or 50 seconds when no such
        fragment is found.

        Args:
            messages: Forwarded unchanged to ``LiteLLMModel.__call__``.
            **kwargs: Forwarded unchanged to ``LiteLLMModel.__call__``.

        Returns:
            Whatever ``LiteLLMModel.__call__`` returns on the first attempt
            that succeeds.

        Raises:
            RateLimitError: The error raised by the final attempt, once every
                attempt has been rate-limited.
        """
        max_retry = 5

        for attempt in range(1, max_retry + 1):
            try:
                return super().__call__(messages, **kwargs)
            except RateLimitError as e:
                print(f"RateLimitError (attempt {attempt}/{max_retry})")

                if attempt == max_retry:
                    # Re-raise the provider's own error instead of building a
                    # new one: litellm's RateLimitError constructor requires
                    # llm_provider and model in addition to the message, so
                    # RateLimitError("...") would itself fail with TypeError.
                    # Bailing out here also avoids a pointless sleep after
                    # the last attempt.
                    print(f"Rate limit exceeded after {max_retry} retries.")
                    raise

                # The provider may embed its suggested wait in the error
                # text; fall back to a fixed delay when it does not.
                match = re.search(r'"retryDelay": ?"(\d+)s"', str(e))
                retry_seconds = int(match.group(1)) if match else 50

                print(f"Sleeping for {retry_seconds} seconds before retrying...")
                time.sleep(retry_seconds)
|
|
|
|
|
@lru_cache(maxsize=1)
def get_lite_llm_model(model_id: str, **kwargs) -> WrapperLiteLLMModel:
    """
    Returns a rate-limit-aware LiteLLM model instance.

    The result is memoized with ``lru_cache(maxsize=1)``: only the most
    recent argument combination is kept, and every argument value must be
    hashable.

    Args:
        model_id (str): The model identifier.
        **kwargs: Additional keyword arguments for the model. An explicit
            ``api_key`` here takes precedence over the environment.

    Returns:
        WrapperLiteLLMModel: LiteLLM model instance.
    """
    # Use the environment key only as a fallback. Passing api_key both
    # explicitly and through **kwargs would raise "got multiple values for
    # keyword argument 'api_key'".
    kwargs.setdefault("api_key", os.getenv("GEMINI_API_KEY"))

    return WrapperLiteLLMModel(model_id=model_id, **kwargs)
|
|
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_local_model(model_id: str, **kwargs) -> LocalTransformersModel:
    """
    Builds (or reuses the cached) local Transformers model wrapper.

    Args:
        model_id (str): Identifier of the model to load.
        **kwargs: Extra keyword arguments passed on to the wrapper's constructor.

    Returns:
        LocalTransformersModel: Local Transformers model instance.
    """
    local_model = LocalTransformersModel(model_id=model_id, **kwargs)
    return local_model
|
|
|
|
|
|
|
|
def get_model(model_type: str, model_id: str, **kwargs) -> Any:
    """
    Builds a model by dispatching to the factory registered for ``model_type``.

    Args:
        model_type (str): Name of the model family, either 'LiteLLMModel'
            or 'LocalTransformersModel'.
        model_id (str): Identifier handed to the selected factory.
        **kwargs: Extra keyword arguments forwarded to the selected factory.

    Returns:
        Any: The object produced by the selected factory.

    Raises:
        ValueError: If ``model_type`` is not one of the registered names.
    """
    factories: dict[str, Callable[..., Any]] = {
        "LiteLLMModel": get_lite_llm_model,
        "LocalTransformersModel": get_local_model,
    }

    factory = factories.get(model_type)
    if factory is None:
        raise ValueError(f"Unknown model type: {model_type}")

    return factory(model_id, **kwargs)
|
|
|