| """ | |
| 🧠 LLM Client for CourseCrafter AI | |
| Multi-provider LLM client with streaming support. | |
| """ | |
| import json | |
| from typing import Dict, List, Any, Optional, AsyncGenerator | |
| from dataclasses import dataclass | |
| from abc import ABC, abstractmethod | |
| import os | |
| import openai | |
| import anthropic | |
| import google.generativeai as genai | |
| from ..types import LLMProvider, StreamChunk | |
| from ..utils.config import config | |


@dataclass
class Message:
    """Standard message format"""
    role: str  # "system", "user", "assistant"
    content: str


class BaseLLMClient(ABC):
    """Abstract base class for LLM clients"""

    def __init__(self, provider: LLMProvider):
        self.provider = provider
        self.config = config.get_llm_config(provider)

    @abstractmethod
    async def generate_stream(self, messages: List[Message]) -> AsyncGenerator[StreamChunk, None]:
        """Generate streaming response"""
        ...


class OpenAIClient(BaseLLMClient):
    """OpenAI client with streaming support (works with OpenAI and compatible endpoints)"""

    def __init__(self, provider: LLMProvider = "openai"):
        super().__init__(provider)
        # Build client kwargs; some OpenAI-compatible endpoints (e.g. local
        # servers) do not require a real API key, so fall back to a placeholder.
        client_kwargs = {
            "api_key": self.config.api_key or "dummy",
            "timeout": self.config.timeout,
        }
        # Add base_url for compatible endpoints
        if getattr(self.config, "base_url", None):
            client_kwargs["base_url"] = self.config.base_url
        self.client = openai.AsyncOpenAI(**client_kwargs)

    def _format_messages(self, messages: List[Message]) -> List[Dict[str, Any]]:
        """Format messages for OpenAI"""
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    async def generate_stream(self, messages: List[Message]) -> AsyncGenerator[StreamChunk, None]:
        """Generate streaming response from OpenAI"""
        formatted_messages = self._format_messages(messages)
        kwargs = {
            "model": self.config.model,
            "messages": formatted_messages,
            "temperature": self.config.temperature,
            "stream": True,
        }
        if self.config.max_tokens:
            kwargs["max_tokens"] = self.config.max_tokens
        try:
            stream = await self.client.chat.completions.create(**kwargs)
            async for chunk in stream:
                if chunk.choices and chunk.choices[0].delta:
                    delta = chunk.choices[0].delta
                    if delta.content:
                        yield StreamChunk(
                            type="text",
                            content=delta.content,
                        )
        except Exception as e:
            yield StreamChunk(
                type="error",
                content=f"OpenAI API error: {str(e)}",
            )


class AnthropicClient(BaseLLMClient):
    """Anthropic client with streaming support"""

    def __init__(self):
        super().__init__("anthropic")
        self.client = anthropic.AsyncAnthropic(
            api_key=self.config.api_key,
            timeout=self.config.timeout,
        )

    def _format_messages(self, messages: List[Message]) -> tuple[List[Dict[str, Any]], Optional[str]]:
        """Format messages for Anthropic"""
        formatted = []
        system_message = None
        for msg in messages:
            if msg.role == "system":
                system_message = msg.content
            elif msg.role in ["user", "assistant"]:
                formatted.append({
                    "role": msg.role,
                    "content": msg.content,
                })
        return formatted, system_message

    async def generate_stream(self, messages: List[Message]) -> AsyncGenerator[StreamChunk, None]:
        """Generate streaming response from Anthropic"""
        formatted_messages, system_message = self._format_messages(messages)
        kwargs = {
            "model": self.config.model,
            "messages": formatted_messages,
            "temperature": self.config.temperature,
            "stream": True,
        }
        if system_message:
            kwargs["system"] = system_message
        # The Anthropic Messages API requires max_tokens, so always set it,
        # falling back to a conservative default when none is configured.
        kwargs["max_tokens"] = self.config.max_tokens or 4096
        try:
            stream = await self.client.messages.create(**kwargs)
            async for chunk in stream:
                if chunk.type == "content_block_delta":
                    if hasattr(chunk.delta, "text"):
                        yield StreamChunk(
                            type="text",
                            content=chunk.delta.text,
                        )
        except Exception as e:
            yield StreamChunk(
                type="error",
                content=f"Anthropic API error: {str(e)}",
            )


class GoogleClient(BaseLLMClient):
    """Google Gemini client with streaming support"""

    def __init__(self):
        super().__init__("google")
        genai.configure(api_key=self.config.api_key)
        self.model = genai.GenerativeModel(self.config.model)

    def _format_messages(self, messages: List[Message]) -> List[Dict[str, Any]]:
        """Format messages for Google"""
        formatted = []
        for msg in messages:
            if msg.role == "system":
                # Google handles system messages differently
                formatted.append({
                    "role": "user",
                    "parts": [{"text": f"System: {msg.content}"}],
                })
            elif msg.role == "user":
                formatted.append({
                    "role": "user",
                    "parts": [{"text": msg.content}],
                })
            elif msg.role == "assistant":
                formatted.append({
                    "role": "model",
                    "parts": [{"text": msg.content}],
                })
        return formatted

    async def generate_stream(self, messages: List[Message]) -> AsyncGenerator[StreamChunk, None]:
        """Generate streaming response from Google"""
        formatted_messages = self._format_messages(messages)
        generation_config = {
            "temperature": self.config.temperature,
        }
        if self.config.max_tokens:
            generation_config["max_output_tokens"] = self.config.max_tokens
        try:
            response = await self.model.generate_content_async(
                formatted_messages,
                generation_config=generation_config,
                stream=True,
            )
            async for chunk in response:
                if chunk.text:
                    yield StreamChunk(
                        type="text",
                        content=chunk.text,
                    )
        except Exception as e:
            yield StreamChunk(
                type="error",
                content=f"Google API error: {str(e)}",
            )


class LlmClient:
    """
    Unified LLM client that manages multiple providers
    """

    def __init__(self):
        self.clients = {}
        self._initialize_clients()

    def _initialize_clients(self):
        """Initialize available LLM clients"""
        available_providers = config.get_available_llm_providers()
        for provider in available_providers:
            try:
                if provider in ["openai", "openai_compatible"]:
                    self.clients[provider] = OpenAIClient(provider)
                elif provider == "anthropic":
                    self.clients[provider] = AnthropicClient()
                elif provider == "google":
                    self.clients[provider] = GoogleClient()
                print(f"✅ Initialized {provider} client")
            except Exception as e:
                print(f"❌ Failed to initialize {provider} client: {e}")

    def update_provider_config(self, provider: str, api_key: Optional[str] = None, **kwargs):
        """Update configuration for a specific provider and reinitialize its client"""
        # Update environment variables
        if provider == "openai" and api_key:
            os.environ["OPENAI_API_KEY"] = api_key
        elif provider == "anthropic" and api_key:
            os.environ["ANTHROPIC_API_KEY"] = api_key
        elif provider == "google" and api_key:
            os.environ["GOOGLE_API_KEY"] = api_key
        elif provider == "openai_compatible":
            if api_key:
                os.environ["OPENAI_COMPATIBLE_API_KEY"] = api_key
            if kwargs.get("base_url"):
                os.environ["OPENAI_COMPATIBLE_BASE_URL"] = kwargs["base_url"]
            if kwargs.get("model"):
                os.environ["OPENAI_COMPATIBLE_MODEL"] = kwargs["model"]
        # Reinitialize the specific client
        try:
            if provider in ["openai", "openai_compatible"]:
                self.clients[provider] = OpenAIClient(provider)
            elif provider == "anthropic":
                self.clients[provider] = AnthropicClient()
            elif provider == "google":
                self.clients[provider] = GoogleClient()
            print(f"✅ Updated and reinitialized {provider} client")
            return True
        except Exception as e:
            print(f"❌ Failed to reinitialize {provider} client: {e}")
            return False

    def get_available_providers(self) -> List[LLMProvider]:
        """Get list of available providers"""
        return list(self.clients.keys())

    def get_client(self, provider: LLMProvider) -> BaseLLMClient:
        """Get client for specific provider"""
        if provider not in self.clients:
            raise ValueError(f"Provider {provider} not available")
        return self.clients[provider]

    async def generate_stream(
        self,
        provider: LLMProvider,
        messages: List[Message],
    ) -> AsyncGenerator[StreamChunk, None]:
        """Generate streaming response using specified provider"""
        client = self.get_client(provider)
        async for chunk in client.generate_stream(messages):
            yield chunk
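

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal, hedged example of driving the unified client. It assumes at
# least one provider is configured via environment variables, and, because
# this module uses relative imports, it must be run as a module within its
# package (python -m <package>.llm_client) rather than as a script.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        client = LlmClient()
        providers = client.get_available_providers()
        if not providers:
            print("No LLM providers configured.")
            return
        messages = [
            Message(role="system", content="You are a concise assistant."),
            Message(role="user", content="Explain streaming APIs in one sentence."),
        ]
        # Stream from the first available provider, printing text as it arrives.
        async for chunk in client.generate_stream(providers[0], messages):
            if chunk.type == "text":
                print(chunk.content, end="", flush=True)
            elif chunk.type == "error":
                print(f"\n{chunk.content}")
        print()

    asyncio.run(_demo())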