Spaces:
Paused
Paused
File size: 8,365 Bytes
4efde5d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
from typing import Optional, List, Dict, Any, Tuple
from .registry import registry
from .models import Model, ModelCapability
from utils.logger import logger
from .registry import DEFAULT_PREMIUM_MODEL, DEFAULT_FREE_MODEL
class ModelManager:
    """Facade over the model registry: lookup, validation, pricing and selection.

    Every lookup is delegated to the ``registry`` object imported from
    ``.registry``. This class adds filtering, sorting, cost arithmetic,
    response formatting, and per-user default-model selection on top of it.
    """

    def __init__(self):
        self.registry = registry

    def get_model(self, model_id: str) -> Optional[Model]:
        """Return the registry entry for ``model_id`` (``None`` if the registry has none)."""
        return self.registry.get(model_id)

    def resolve_model_id(self, model_id: str) -> str:
        """Map an alias (or canonical id) to the canonical model id.

        If the registry cannot resolve ``model_id``, a warning is logged and
        ``model_id`` is returned unchanged, so callers always get a string.
        """
        logger.debug(f"resolve_model_id called with: '{model_id}' (type: {type(model_id)})")
        resolved = self.registry.resolve_model_id(model_id)
        if resolved:
            logger.debug(f"Resolved model '{model_id}' to '{resolved}'")
            return resolved
        # ``_aliases`` is a private registry attribute used only for this
        # diagnostic; read it defensively so logging can never raise.
        all_aliases = list(getattr(self.registry, "_aliases", {}))
        logger.warning(f"Could not resolve model ID: '{model_id}'. Available aliases: {all_aliases[:10]}...")
        return model_id

    def validate_model(self, model_id: str) -> Tuple[bool, str]:
        """Check that ``model_id`` exists and is enabled.

        Returns:
            ``(True, "")`` when usable, otherwise ``(False, reason)``.
        """
        model = self.get_model(model_id)
        if not model:
            return False, f"Model '{model_id}' not found"
        if not model.enabled:
            return False, f"Model '{model.name}' is currently disabled"
        return True, ""

    def calculate_cost(
        self,
        model_id: str,
        input_tokens: int,
        output_tokens: int
    ) -> Optional[float]:
        """Return the cost of a request.

        Computed as ``input_tokens * input_cost_per_token +
        output_tokens * output_cost_per_token``.

        Returns:
            The total cost, or ``None`` (with a warning) when the model is
            unknown or has no pricing.
        """
        model = self.get_model(model_id)
        if not model or not model.pricing:
            logger.warning(f"No pricing available for model: {model_id}")
            return None
        input_cost = input_tokens * model.pricing.input_cost_per_token
        output_cost = output_tokens * model.pricing.output_cost_per_token
        total_cost = input_cost + output_cost
        logger.debug(
            f"Cost calculation for {model.name}: "
            f"{input_tokens} input tokens (${input_cost:.6f}) + "
            f"{output_tokens} output tokens (${output_cost:.6f}) = "
            f"${total_cost:.6f}"
        )
        return total_cost

    def get_models_for_tier(self, tier: str) -> List[Model]:
        """Return the enabled models the registry lists for ``tier``."""
        return self.registry.get_by_tier(tier, enabled_only=True)

    def get_models_with_capability(self, capability: ModelCapability) -> List[Model]:
        """Return the enabled models the registry lists as having ``capability``."""
        return self.registry.get_by_capability(capability, enabled_only=True)

    def select_best_model(
        self,
        tier: str,
        required_capabilities: Optional[List[ModelCapability]] = None,
        min_context_window: Optional[int] = None,
        prefer_cheaper: bool = False
    ) -> Optional[Model]:
        """Pick the most suitable enabled model for ``tier``.

        Args:
            tier: Tier passed to ``get_models_for_tier``.
            required_capabilities: Every listed capability must be present.
            min_context_window: Minimum ``context_window``; a falsy value
                disables the filter.
            prefer_cheaper: When True and at least one candidate has pricing,
                only priced candidates are ranked. The lowest
                ``input_cost_per_million_tokens`` wins; ties go to higher
                ``priority``, then to recommended models.

        Returns:
            The chosen model, or ``None`` if no candidate passes the filters.
            Without ``prefer_cheaper``, the highest ``priority`` wins, with
            recommended models preferred on ties.
        """
        models = self.get_models_for_tier(tier)
        if required_capabilities:
            models = [
                m for m in models
                if all(cap in m.capabilities for cap in required_capabilities)
            ]
        if min_context_window:
            models = [m for m in models if m.context_window >= min_context_window]
        if not models:
            return None

        priced = [m for m in models if m.pricing]
        if prefer_cheaper and priced:
            # Models without pricing cannot be ranked by price, so they are
            # excluded from this ranking.
            return min(
                priced,
                key=lambda m: (
                    m.pricing.input_cost_per_million_tokens,
                    -m.priority,
                    not m.recommended,
                ),
            )
        # ``min`` returns the first of equal keys, i.e. registry order breaks
        # any remaining ties (same as taking [0] of a stable sort).
        return min(models, key=lambda m: (-m.priority, not m.recommended))

    def get_default_model(self, tier: str = "free") -> Optional[Model]:
        """Return the default enabled model for ``tier``.

        Recommended models take precedence. Among them (or among all of the
        tier's models when none is recommended) the highest ``priority`` wins,
        and ties keep registry order. Returns ``None`` if the tier has no models.
        """
        models = self.get_models_for_tier(tier)
        pool = [m for m in models if m.recommended] or models
        if not pool:
            return None
        # ``max`` keeps the first of equally-prioritised models, matching the
        # first element of a stable descending-priority sort.
        return max(pool, key=lambda m: m.priority)

    def get_context_window(self, model_id: str, default: int = 31_000) -> int:
        """Return the registry's context window for ``model_id``.

        ``default`` is forwarded to the registry, presumably as the fallback
        for unknown ids; the registry's exact behaviour is not visible here.
        """
        return self.registry.get_context_window(model_id, default)

    def check_token_limit(
        self,
        model_id: str,
        token_count: int,
        is_input: bool = True
    ) -> Tuple[bool, int]:
        """Check ``token_count`` against the model's input or output limit.

        Input tokens are checked against ``context_window``. Output tokens are
        checked against ``max_output_tokens``, falling back to
        ``context_window`` when that is unset.

        Returns:
            ``(within_limit, max_allowed)``; ``(False, 0)`` for an unknown model.
        """
        model = self.get_model(model_id)
        if not model:
            return False, 0
        if is_input:
            max_allowed = model.context_window
        else:
            max_allowed = model.max_output_tokens or model.context_window
        return token_count <= max_allowed, max_allowed

    def format_model_info(self, model_id: str) -> Dict[str, Any]:
        """Return a summary dict for ``model_id``.

        Unknown ids produce ``{"error": ...}`` instead of raising.
        """
        model = self.get_model(model_id)
        if not model:
            return {"error": f"Model '{model_id}' not found"}
        return self._format_model(model)

    @staticmethod
    def _format_model(model: Model) -> Dict[str, Any]:
        """Build the summary dict for an already-fetched ``Model``."""
        return {
            "id": model.id,
            "name": model.name,
            "provider": model.provider.value,
            "context_window": model.context_window,
            "max_output_tokens": model.max_output_tokens,
            "capabilities": [cap.value for cap in model.capabilities],
            "pricing": {
                "input_per_million": model.pricing.input_cost_per_million_tokens,
                "output_per_million": model.pricing.output_cost_per_million_tokens,
            } if model.pricing else None,
            "enabled": model.enabled,
            "beta": model.beta,
            "tier_availability": model.tier_availability,
            "priority": model.priority,
            "recommended": model.recommended,
        }

    def list_available_models(
        self,
        tier: Optional[str] = None,
        include_disabled: bool = False
    ) -> List[Dict[str, Any]]:
        """List models as summary dicts.

        Args:
            tier: Restrict to this tier; ``None`` (or empty) lists all models.
            include_disabled: Also include disabled models.

        Returns:
            Summaries ordered with free-tier models first, then by descending
            ``priority``, then by ``name``.
        """
        logger.debug(f"list_available_models called with tier='{tier}', include_disabled={include_disabled}")
        if tier:
            models = self.registry.get_by_tier(tier, enabled_only=not include_disabled)
            logger.debug(f"Found {len(models)} models for tier '{tier}'")
        else:
            models = self.registry.get_all(enabled_only=not include_disabled)
            logger.debug(f"Found {len(models)} total models")
        if models:
            model_names = [m.name for m in models]
            logger.debug(f"Models: {model_names}")
        else:
            logger.warning(f"No models found for tier '{tier}' - this might indicate a configuration issue")
        # ``not is_free_tier`` is False for free-tier models, so they sort first.
        models = sorted(
            models,
            key=lambda m: (not m.is_free_tier, -m.priority, m.name)
        )
        # Format the models we already hold instead of re-fetching each by id.
        return [self._format_model(m) for m in models]

    def get_legacy_constants(self) -> Dict:
        """Return ``registry.to_legacy_format()`` unchanged."""
        return self.registry.to_legacy_format()

    @staticmethod
    def _extract_price_id(subscription: Dict[str, Any]) -> Optional[str]:
        """Pull the price id out of a subscription payload.

        Reads ``items.data[0].price.id`` when ``items.data`` is present and
        non-empty, otherwise the top-level ``price_id`` key. The nested shape
        looks like a Stripe subscription object — NOTE(review): confirm.
        """
        items = subscription.get('items')
        if items and items.get('data'):
            return items['data'][0]['price']['id']
        return subscription.get('price_id')

    async def get_default_model_for_user(self, client, user_id: str) -> str:
        """Return the default model id for ``user_id`` based on their subscription.

        In ``EnvMode.LOCAL`` the premium default is always returned. Otherwise
        the subscription's price id is looked up in ``SUBSCRIPTION_TIERS``, and
        any tier whose ``name`` is not ``'free'`` counts as paid. Any failure
        (imports, billing lookup, unexpected payload shape) is logged as a
        warning and falls back to the free default.

        ``client`` is accepted for interface compatibility but is not used here.
        """
        try:
            from utils.config import config, EnvMode
            if config.ENV_MODE == EnvMode.LOCAL:
                return DEFAULT_PREMIUM_MODEL
            from services.billing import get_user_subscription, SUBSCRIPTION_TIERS
            subscription = await get_user_subscription(user_id)
            is_paid_tier = False
            if subscription:
                tier_info = SUBSCRIPTION_TIERS.get(self._extract_price_id(subscription))
                is_paid_tier = bool(tier_info) and tier_info['name'] != 'free'
            if is_paid_tier:
                # Log the actual constant rather than a hard-coded model name
                # that can drift from the registry defaults.
                logger.debug(f"Setting premium default model '{DEFAULT_PREMIUM_MODEL}' for paid user {user_id}")
                return DEFAULT_PREMIUM_MODEL
            logger.debug(f"Setting free default model '{DEFAULT_FREE_MODEL}' for free user {user_id}")
            return DEFAULT_FREE_MODEL
        except Exception as e:
            # Deliberate best-effort: tier detection must never block a request.
            logger.warning(f"Failed to determine user tier for {user_id}: {e}")
            return DEFAULT_FREE_MODEL
# Shared module-level instance used by importers of this module.
model_manager = ModelManager()