Mahynlo
Ready: unified env vars, deps, system packages, fixed app
5b715cc
raw
history blame
3.04 kB
import os
from typing import Any, Callable
from smolagents import LiteLLMModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from functools import lru_cache
import time
import re
from litellm import RateLimitError
class LocalTransformersModel:
    """Text-generation model backed by a locally loaded Hugging Face pipeline."""

    def __init__(self, model_id: str, **kwargs):
        """Load the tokenizer and causal-LM weights for *model_id*.

        Args:
            model_id (str): Hugging Face model identifier.
            **kwargs: Extra keyword arguments forwarded to the model loader.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
        self.pipeline = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer
        )

    def __call__(self, prompt: str, **kwargs):
        """Generate text for *prompt*; returns the first candidate's generated_text."""
        generations = self.pipeline(prompt, **kwargs)
        return generations[0]["generated_text"]
class WrapperLiteLLMModel(LiteLLMModel):
    """LiteLLMModel that retries calls when the provider rate-limits.

    Retries up to ``MAX_RETRIES`` times, honouring the ``"retryDelay"``
    hint embedded in the provider's error payload when present.
    """

    MAX_RETRIES = 5
    DEFAULT_RETRY_SECONDS = 50

    def __call__(self, messages, **kwargs):
        """Delegate to LiteLLMModel.__call__, retrying on RateLimitError.

        Raises:
            RateLimitError: Re-raised unchanged once all retries are exhausted.
        """
        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                return super().__call__(messages, **kwargs)
            except RateLimitError as err:
                print(f"RateLimitError (attempt {attempt}/{self.MAX_RETRIES})")
                if attempt == self.MAX_RETRIES:
                    # Out of retries: re-raise the original, fully-formed exception.
                    # (Constructing a new RateLimitError with a single string would
                    # fail — litellm's constructor needs provider/model arguments —
                    # and sleeping again before giving up would be wasted time.)
                    raise
                # Honour a provider-supplied retry hint, e.g. '"retryDelay": "30s"'.
                match = re.search(r'"retryDelay": ?"(\d+)s"', str(err))
                retry_seconds = (
                    int(match.group(1)) if match else self.DEFAULT_RETRY_SECONDS
                )
                print(f"Sleeping for {retry_seconds} seconds before retrying...")
                time.sleep(retry_seconds)
@lru_cache(maxsize=1)
def get_lite_llm_model(model_id: str, **kwargs) -> WrapperLiteLLMModel:
    """
    Build (and memoize) a rate-limit-aware LiteLLM model.

    Args:
        model_id (str): The model identifier.
        **kwargs: Additional keyword arguments forwarded to the model.

    Returns:
        WrapperLiteLLMModel: Cached LiteLLM model instance.
    """
    # API key comes from the unified GEMINI_API_KEY environment variable.
    api_key = os.getenv("GEMINI_API_KEY")
    return WrapperLiteLLMModel(model_id=model_id, api_key=api_key, **kwargs)
@lru_cache(maxsize=1)
def get_local_model(model_id: str, **kwargs) -> LocalTransformersModel:
    """
    Build (and memoize) a local transformers text-generation model.

    Args:
        model_id (str): The model identifier.
        **kwargs: Additional keyword arguments forwarded to the model loader.

    Returns:
        LocalTransformersModel: Cached local model instance.
    """
    return LocalTransformersModel(model_id=model_id, **kwargs)
def get_model(model_type: str, model_id: str, **kwargs) -> Any:
    """
    Build a model instance of the requested type.

    Args:
        model_type (str): One of "LiteLLMModel" or "LocalTransformersModel".
        model_id (str): The model identifier.
        **kwargs: Additional keyword arguments forwarded to the factory.

    Returns:
        Any: The constructed model instance.

    Raises:
        ValueError: If *model_type* is not a known factory name.
    """
    factories: dict[str, Callable[..., Any]] = {
        "LiteLLMModel": get_lite_llm_model,
        "LocalTransformersModel": get_local_model,
    }
    try:
        factory = factories[model_type]
    except KeyError:
        raise ValueError(f"Unknown model type: {model_type}") from None
    return factory(model_id, **kwargs)