| """Groq API integration with streaming and optimizations.""" | |
| import os | |
| import logging | |
| import asyncio | |
| from typing import Dict, Any, Optional, List, AsyncGenerator, Union | |
| import groq | |
| from datetime import datetime | |
| import json | |
| from dataclasses import dataclass | |
| from concurrent.futures import ThreadPoolExecutor | |
| from .base import ReasoningStrategy, StrategyResult | |
| logger = logging.getLogger(__name__) | |

@dataclass
class GroqConfig:
    """Configuration for Groq models."""
    model_name: str
    max_tokens: int
    temperature: float
    top_p: float
    top_k: Optional[int] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    stop_sequences: Optional[List[str]] = None
    chunk_size: int = 1024
    retry_attempts: int = 3
    retry_delay: float = 1.0
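
# A custom configuration can be registered on a strategy instance at runtime
# (sketch only; the model name below is illustrative and must be one available
# to your Groq account):
#   strategy = GroqStrategy()
#   strategy.model_configs["llama3"] = GroqConfig(
#       model_name="llama3-70b-8192",
#       max_tokens=8192,
#       temperature=0.5,
#       top_p=0.9,
#   )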

class GroqStrategy(ReasoningStrategy):
    """Enhanced reasoning strategy using Groq's API with streaming and optimizations."""

    def __init__(self, api_key: Optional[str] = None):
        """Initialize Groq strategy."""
        super().__init__()
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError("GROQ_API_KEY must be set")

        # Async Groq client so completions can be awaited and streamed
        self.client = groq.AsyncGroq(
            api_key=self.api_key,
            timeout=30,
            max_retries=3
        )

        # Optimized model configurations
        self.model_configs = {
            "mixtral": GroqConfig(
                model_name="mixtral-8x7b-32768",
                max_tokens=32768,
                temperature=0.7,
                top_p=0.9,
                top_k=40,
                presence_penalty=0.1,
                frequency_penalty=0.1,
                chunk_size=4096
            ),
            "llama": GroqConfig(
                model_name="llama2-70b-4096",
                max_tokens=4096,
                temperature=0.8,
                top_p=0.9,
                top_k=50,
                presence_penalty=0.2,
                frequency_penalty=0.2,
                chunk_size=1024
            )
        }

        # Thread pool for parallel processing
        self.executor = ThreadPoolExecutor(max_workers=4)

        # Response cache keyed by (query, context, model)
        self.cache: Dict[str, Any] = {}
        self.cache_ttl = 3600  # seconds (1 hour)

    async def reason_stream(
        self,
        query: str,
        context: Dict[str, Any],
        model: str = "mixtral",
        chunk_handler: Optional[Callable[[str], Any]] = None
    ) -> AsyncGenerator[str, None]:
        """
        Stream reasoning results from Groq's API.

        Args:
            query: The query to reason about
            context: Additional context
            model: Model to use ('mixtral' or 'llama')
            chunk_handler: Optional async callback invoked with each content chunk
        """
        config = self.model_configs[model]
        messages = self._prepare_messages(query, context)

        try:
            stream = await self.client.chat.completions.create(
                model=config.model_name,
                messages=messages,
                temperature=config.temperature,
                top_p=config.top_p,
                top_k=config.top_k,
                presence_penalty=config.presence_penalty,
                frequency_penalty=config.frequency_penalty,
                max_tokens=config.max_tokens,
                stream=True
            )

            collected_content = []
            async for chunk in stream:
                if chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    collected_content.append(content)
                    if chunk_handler:
                        await chunk_handler(content)
                    yield content

            # Cache the complete response
            cache_key = self._generate_cache_key(query, context, model)
            self.cache[cache_key] = {
                "content": "".join(collected_content),
                "timestamp": datetime.now()
            }

        except Exception as e:
            logger.error(f"Groq streaming error: {str(e)}")
            yield f"Error: {str(e)}"

    async def reason(
        self,
        query: str,
        context: Dict[str, Any],
        model: str = "mixtral"
    ) -> StrategyResult:
        """
        Enhanced reasoning with Groq's API including optimizations.

        Args:
            query: The query to reason about
            context: Additional context
            model: Model to use ('mixtral' or 'llama')
        """
        # Check cache first
        cache_key = self._generate_cache_key(query, context, model)
        cached_response = self._get_from_cache(cache_key)
        if cached_response:
            return self._create_result(cached_response, model, from_cache=True)

        config = self.model_configs[model]
        messages = self._prepare_messages(query, context)

        # Implement retry logic with exponential backoff
        for attempt in range(config.retry_attempts):
            try:
                start_time = datetime.now()

                # Make API call with optimized parameters
                response = await self.client.chat.completions.create(
                    model=config.model_name,
                    messages=messages,
                    temperature=config.temperature,
                    top_p=config.top_p,
                    top_k=config.top_k,
                    presence_penalty=config.presence_penalty,
                    frequency_penalty=config.frequency_penalty,
                    max_tokens=config.max_tokens,
                    stream=False
                )

                end_time = datetime.now()
                latency = (end_time - start_time).total_seconds()

                # Cache successful response
                self.cache[cache_key] = {
                    "content": response.choices[0].message.content,
                    "timestamp": datetime.now()
                }

                return self._create_result(response, model, latency=latency)

            except Exception as e:
                delay = config.retry_delay * (2 ** attempt)
                logger.warning(f"Groq API attempt {attempt + 1} failed: {str(e)}")

                if attempt < config.retry_attempts - 1:
                    await asyncio.sleep(delay)
                else:
                    logger.error(f"All Groq API attempts failed: {str(e)}")
                    return self._create_error_result(str(e))

    def _create_result(
        self,
        response: Union[Dict, Any],
        model: str,
        from_cache: bool = False,
        latency: Optional[float] = None
    ) -> StrategyResult:
        """Create a strategy result from response."""
        if from_cache:
            answer = response["content"]
            confidence = 0.9  # Higher confidence for cached responses
            performance_metrics = {
                "from_cache": True,
                "cache_age": (datetime.now() - response["timestamp"]).total_seconds()
            }
        else:
            answer = response.choices[0].message.content
            confidence = self._calculate_confidence(response)
            performance_metrics = {
                "latency": latency,  # wall-clock seconds for the API call
                "tokens_used": response.usage.total_tokens,
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "model": self.model_configs[model].model_name
            }

        return StrategyResult(
            strategy_type="groq",
            success=True,
            answer=answer,
            confidence=confidence,
            reasoning_trace=[{
                "step": "groq_api_call",
                "model": self.model_configs[model].model_name,
                "timestamp": datetime.now().isoformat(),
                "metrics": performance_metrics
            }],
            metadata={
                "model": self.model_configs[model].model_name,
                "from_cache": from_cache
            },
            performance_metrics=performance_metrics
        )

    def _create_error_result(self, error: str) -> StrategyResult:
        """Create an error result."""
        return StrategyResult(
            strategy_type="groq",
            success=False,
            answer=None,
            confidence=0.0,
            reasoning_trace=[{
                "step": "groq_api_error",
                "error": error,
                "timestamp": datetime.now().isoformat()
            }],
            metadata={"error": error},
            performance_metrics={}
        )

    def _generate_cache_key(
        self,
        query: str,
        context: Dict[str, Any],
        model: str
    ) -> str:
        """Generate a deterministic cache key from the query, context, and model."""
        key_data = {
            "query": query,
            "context": context,
            "model": model
        }
        # default=str keeps key generation from failing on non-JSON-serializable context values
        return json.dumps(key_data, sort_keys=True, default=str)

    def _get_from_cache(self, cache_key: str) -> Optional[Dict]:
        """Get response from cache if valid."""
        if cache_key in self.cache:
            cached = self.cache[cache_key]
            age = (datetime.now() - cached["timestamp"]).total_seconds()
            if age < self.cache_ttl:
                return cached
            else:
                del self.cache[cache_key]
        return None

    def _calculate_confidence(self, response: Any) -> float:
        """Calculate confidence score from response."""
        confidence = 0.8  # Base confidence

        # Adjust based on token usage and model behavior
        if hasattr(response, 'usage'):
            completion_tokens = response.usage.completion_tokens
            total_tokens = response.usage.total_tokens

            # Length-based adjustment
            if completion_tokens < 10:
                confidence *= 0.8  # Reduce confidence for very short responses
            elif completion_tokens > 100:
                confidence *= 1.1  # Increase confidence for detailed responses

            # Token efficiency adjustment
            token_efficiency = completion_tokens / total_tokens
            if token_efficiency > 0.5:
                confidence *= 1.1  # Good token efficiency

        # Response completeness check
        if hasattr(response.choices[0], 'finish_reason'):
            if response.choices[0].finish_reason == "stop":
                confidence *= 1.1  # Natural completion
            elif response.choices[0].finish_reason == "length":
                confidence *= 0.9  # Truncated response

        return min(1.0, max(0.0, confidence))  # Ensure between 0 and 1

    def _prepare_messages(
        self,
        query: str,
        context: Dict[str, Any]
    ) -> List[Dict[str, str]]:
        """Prepare messages for the Groq API."""
        messages = []

        # Add system message if provided
        if "system_message" in context:
            messages.append({
                "role": "system",
                "content": context["system_message"]
            })

        # Add chat history if provided
        if "chat_history" in context:
            messages.extend(context["chat_history"])

        # Add the current query
        messages.append({
            "role": "user",
            "content": query
        })

        return messages
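
# Usage sketch. Assumes GROQ_API_KEY is exported and that this module is run as
# part of its package (e.g. `python -m <package>.<module>`) so the relative
# import of .base resolves; model keys must exist in GroqStrategy.model_configs.
async def _demo() -> None:
    strategy = GroqStrategy()
    context = {"system_message": "You are a concise technical assistant."}

    # Streaming: print content chunks as they arrive.
    async for chunk in strategy.reason_stream("Summarize what this module does.", context):
        print(chunk, end="", flush=True)
    print()

    # Non-streaming: a full StrategyResult with confidence and performance metrics.
    result = await strategy.reason("Summarize what this module does.", context, model="mixtral")
    print(f"confidence={result.confidence:.2f} from_cache={result.metadata.get('from_cache')}")


if __name__ == "__main__":
    asyncio.run(_demo())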