# Mahynlo
# Ready: unified env vars, deps, system packages, fixed app
# 5b715cc
import os
from typing import Any, Callable
from smolagents import LiteLLMModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from functools import lru_cache
import time
import re
from litellm import RateLimitError
class LocalTransformersModel:
    """Callable text-generation wrapper around a local Hugging Face causal LM.

    Loads the tokenizer and weights for ``model_id`` and exposes a simple
    ``prompt -> generated text`` interface via a transformers pipeline.
    """

    def __init__(self, model_id: str, **kwargs):
        # Load tokenizer and model weights, then assemble the generation pipeline.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
        self.pipeline = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer
        )

    def __call__(self, prompt: str, **kwargs):
        # Generate and return only the text of the first candidate sequence.
        candidates = self.pipeline(prompt, **kwargs)
        return candidates[0]["generated_text"]
class WrapperLiteLLMModel(LiteLLMModel):
    """LiteLLMModel that retries on rate-limit errors with server-suggested backoff."""

    def __call__(self, messages, **kwargs):
        """Invoke the underlying model, retrying up to 5 times on RateLimitError.

        The sleep duration is parsed from the provider's ``retryDelay`` hint in
        the exception text when present, falling back to 50 seconds. After the
        final failed attempt the last RateLimitError is re-raised (so callers
        catching RateLimitError keep working).
        """
        max_retry = 5
        last_error = None
        for attempt in range(max_retry):
            try:
                return super().__call__(messages, **kwargs)
            except RateLimitError as e:
                last_error = e
                print(f"RateLimitError (attempt {attempt+1}/{max_retry})")
                # Don't sleep after the final attempt — we're about to give up.
                if attempt == max_retry - 1:
                    break
                # Try to extract the retry time from the exception string;
                # allow fractional seconds (e.g. "1.5s").
                match = re.search(r'"retryDelay": ?"(\d+(?:\.\d+)?)s"', str(e))
                retry_seconds = float(match.group(1)) if match else 50
                print(f"Sleeping for {retry_seconds} seconds before retrying...")
                time.sleep(retry_seconds)
        # Re-raise the original exception rather than constructing a new
        # RateLimitError: litellm's constructor requires model/llm_provider
        # arguments, so the original terminal raise would itself fail.
        raise last_error
@lru_cache(maxsize=1)
def get_lite_llm_model(model_id: str, **kwargs) -> WrapperLiteLLMModel:
    """
    Returns a memoized, rate-limit-aware LiteLLM model instance.

    Args:
        model_id (str): The model identifier.
        **kwargs: Additional keyword arguments for the model.

    Returns:
        WrapperLiteLLMModel: LiteLLM model instance with retry behavior.
    """
    # Credentials come from the unified environment variable GEMINI_API_KEY.
    api_key = os.getenv("GEMINI_API_KEY")
    return WrapperLiteLLMModel(model_id=model_id, api_key=api_key, **kwargs)
@lru_cache(maxsize=1)
def get_local_model(model_id: str, **kwargs) -> LocalTransformersModel:
    """
    Returns a memoized local Transformers model.

    Args:
        model_id (str): The model identifier.
        **kwargs: Additional keyword arguments forwarded to model loading.

    Returns:
        LocalTransformersModel: Local Transformers model instance.
    """
    return LocalTransformersModel(model_id=model_id, **kwargs)
def get_model(model_type: str, model_id: str, **kwargs) -> Any:
    """
    Returns a model instance based on the specified type.

    Args:
        model_type (str): The type of the model (e.g., 'HfApiModel').
        model_id (str): The model identifier.
        **kwargs: Additional keyword arguments for the model.

    Returns:
        Any: Model instance of the specified type.

    Raises:
        ValueError: If ``model_type`` is not a recognized model type.
    """
    # Dispatch table mapping model-type names to their factory functions.
    factories: dict[str, Callable[..., Any]] = {
        "LiteLLMModel": get_lite_llm_model,
        "LocalTransformersModel": get_local_model,
    }
    factory = factories.get(model_type)
    if factory is None:
        raise ValueError(f"Unknown model type: {model_type}")
    return factory(model_id, **kwargs)