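"""Model management helpers layered on top of the model registry."""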
from typing import Optional, List, Dict, Any, Tuple
from .registry import registry, DEFAULT_PREMIUM_MODEL, DEFAULT_FREE_MODEL
from .models import Model, ModelCapability
from utils.logger import logger

class ModelManager:
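    """High-level facade over the model registry: resolution, validation, pricing, and selection."""
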
    def __init__(self):
        self.registry = registry
    
    def get_model(self, model_id: str) -> Optional[Model]:
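        """Look up a model by ID in the registry, returning None if it is unknown."""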
        return self.registry.get(model_id)
    
    def resolve_model_id(self, model_id: str) -> str:
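        """Resolve an alias or short name to a canonical model ID, falling back to the input unchanged."""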
        logger.debug(f"resolve_model_id called with: '{model_id}' (type: {type(model_id)})")
        
        resolved = self.registry.resolve_model_id(model_id)
        if resolved:
            logger.debug(f"Resolved model '{model_id}' to '{resolved}'")
            return resolved
        
        all_aliases = list(self.registry._aliases.keys())
        logger.warning(f"Could not resolve model ID: '{model_id}'. Available aliases: {all_aliases[:10]}...")
        return model_id
    
    def validate_model(self, model_id: str) -> Tuple[bool, str]:
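        """Check that a model exists and is enabled; returns (is_valid, error_message)."""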
        model = self.get_model(model_id)
        
        if not model:
            return False, f"Model '{model_id}' not found"
        
        if not model.enabled:
            return False, f"Model '{model.name}' is currently disabled"
        
        return True, ""
    
    def calculate_cost(
        self,
        model_id: str,
        input_tokens: int,
        output_tokens: int
    ) -> Optional[float]:
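        """Compute the cost of a request from per-token pricing; returns None if the model has no pricing."""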
        model = self.get_model(model_id)
        if not model or not model.pricing:
            logger.warning(f"No pricing available for model: {model_id}")
            return None
        
        input_cost = input_tokens * model.pricing.input_cost_per_token
        output_cost = output_tokens * model.pricing.output_cost_per_token
        total_cost = input_cost + output_cost
        
        logger.debug(
            f"Cost calculation for {model.name}: "
            f"{input_tokens} input tokens (${input_cost:.6f}) + "
            f"{output_tokens} output tokens (${output_cost:.6f}) = "
            f"${total_cost:.6f}"
        )
        
        return total_cost
    
    def get_models_for_tier(self, tier: str) -> List[Model]:
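        """Return all enabled models available to the given tier."""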
        return self.registry.get_by_tier(tier, enabled_only=True)
    
    def get_models_with_capability(self, capability: ModelCapability) -> List[Model]:
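        """Return all enabled models that advertise the given capability."""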
        return self.registry.get_by_capability(capability, enabled_only=True)
    
    def select_best_model(
        self,
        tier: str,
        required_capabilities: Optional[List[ModelCapability]] = None,
        min_context_window: Optional[int] = None,
        prefer_cheaper: bool = False
    ) -> Optional[Model]:
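        """Pick the best enabled model for a tier, optionally requiring capabilities and a
        minimum context window; orders by input price when prefer_cheaper is set, otherwise
        by priority and recommendation."""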
        models = self.get_models_for_tier(tier)
        
        if required_capabilities:
            models = [
                m for m in models
                if all(cap in m.capabilities for cap in required_capabilities)
            ]
        
        if min_context_window:
            models = [m for m in models if m.context_window >= min_context_window]
        
        if not models:
            return None
        
        if prefer_cheaper and any(m.pricing for m in models):
            models_with_pricing = [m for m in models if m.pricing]
            if models_with_pricing:
                models = sorted(
                    models_with_pricing,
                    key=lambda m: m.pricing.input_cost_per_million_tokens
                )
        else:
            models = sorted(
                models,
                key=lambda m: (-m.priority, not m.recommended)
            )
        
        return models[0] if models else None
    
    def get_default_model(self, tier: str = "free") -> Optional[Model]:
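        """Return the highest-priority recommended model for a tier, falling back to the
        highest-priority model overall, or None if the tier has no models."""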
        models = self.get_models_for_tier(tier)
        
        recommended = [m for m in models if m.recommended]
        if recommended:
            recommended = sorted(recommended, key=lambda m: -m.priority)
            return recommended[0]
        
        if models:
            models = sorted(models, key=lambda m: -m.priority)
            return models[0]
        
        return None
    
    def get_context_window(self, model_id: str, default: int = 31_000) -> int:
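        """Return the model's context window in tokens, or the default if unknown."""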
        return self.registry.get_context_window(model_id, default)
    
    def check_token_limit(
        self,
        model_id: str,
        token_count: int,
        is_input: bool = True
    ) -> Tuple[bool, int]:
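        """Check a token count against the model's input or output limit; returns
        (within_limit, max_allowed), or (False, 0) for an unknown model."""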
        model = self.get_model(model_id)
        if not model:
            return False, 0
        
        if is_input:
            max_allowed = model.context_window
        else:
            max_allowed = model.max_output_tokens or model.context_window
        
        return token_count <= max_allowed, max_allowed
    
    def format_model_info(self, model_id: str) -> Dict[str, Any]:
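        """Serialize a model's metadata into a plain dictionary, or an error dict if the model is unknown."""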
        model = self.get_model(model_id)
        if not model:
            return {"error": f"Model '{model_id}' not found"}
        
        return {
            "id": model.id,
            "name": model.name,
            "provider": model.provider.value,
            "context_window": model.context_window,
            "max_output_tokens": model.max_output_tokens,
            "capabilities": [cap.value for cap in model.capabilities],
            "pricing": {
                "input_per_million": model.pricing.input_cost_per_million_tokens,
                "output_per_million": model.pricing.output_cost_per_million_tokens,
            } if model.pricing else None,
            "enabled": model.enabled,
            "beta": model.beta,
            "tier_availability": model.tier_availability,
            "priority": model.priority,
            "recommended": model.recommended,
        }
    
    def list_available_models(
        self,
        tier: Optional[str] = None,
        include_disabled: bool = False
    ) -> List[Dict[str, Any]]:
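        """List models as info dictionaries, optionally filtered by tier, with free-tier
        models first and then ordered by priority and name."""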
        logger.debug(f"list_available_models called with tier='{tier}', include_disabled={include_disabled}")
        
        if tier:
            models = self.registry.get_by_tier(tier, enabled_only=not include_disabled)
            logger.debug(f"Found {len(models)} models for tier '{tier}'")
        else:
            models = self.registry.get_all(enabled_only=not include_disabled)
            logger.debug(f"Found {len(models)} total models")
        
        if models:
            model_names = [m.name for m in models]
            logger.debug(f"Models: {model_names}")
        else:
            logger.warning(f"No models found for tier '{tier}' - this might indicate a configuration issue")
        
        models = sorted(
            models,
            key=lambda m: (not m.is_free_tier, -m.priority, m.name)
        )
        
        return [self.format_model_info(m.id) for m in models]
    
    def get_legacy_constants(self) -> Dict:
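        """Return the registry contents in the legacy constants format."""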
        return self.registry.to_legacy_format()
    
    async def get_default_model_for_user(self, client, user_id: str) -> str:
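        """Resolve the default model ID for a user: the premium default in local mode or for
        paid subscribers, otherwise (including on lookup errors) the free default."""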
        try:
            from utils.config import config, EnvMode
            if config.ENV_MODE == EnvMode.LOCAL:
                return DEFAULT_PREMIUM_MODEL
                
            from services.billing import get_user_subscription, SUBSCRIPTION_TIERS
            
            subscription = await get_user_subscription(user_id)
            
            is_paid_tier = False
            if subscription:
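                # Pull the price ID either from the nested items payload or from a flat 'price_id' field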
                price_id = None
                if subscription.get('items') and subscription['items'].get('data') and len(subscription['items']['data']) > 0:
                    price_id = subscription['items']['data'][0]['price']['id']
                else:
                    price_id = subscription.get('price_id')
                
                tier_info = SUBSCRIPTION_TIERS.get(price_id)
                if tier_info and tier_info['name'] != 'free':
                    is_paid_tier = True
            
            if is_paid_tier:
                logger.debug(f"Using premium default model '{DEFAULT_PREMIUM_MODEL}' for paid user {user_id}")
                return DEFAULT_PREMIUM_MODEL
            else:
                logger.debug(f"Using free default model '{DEFAULT_FREE_MODEL}' for free user {user_id}")
                return DEFAULT_FREE_MODEL
                
        except Exception as e:
            logger.warning(f"Failed to determine user tier for {user_id}: {e}")
            return DEFAULT_FREE_MODEL


model_manager = ModelManager()