Spaces:
Running
Running
| # Standard Library Imports | |
| import json | |
| from typing import List, Tuple | |
| # Third-Party Library Imports | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| # Local Application Imports | |
| from src.common import ( | |
| ComparisonType, | |
| LeaderboardEntry, | |
| OptionKey, | |
| OptionMap, | |
| TTSProviderName, | |
| VotingResults, | |
| constants, | |
| logger, | |
| ) | |
| from src.database import ( | |
| AsyncDBSessionMaker, | |
| create_vote, | |
| get_head_to_head_battle_stats, | |
| get_head_to_head_win_rate_stats, | |
| get_leaderboard_stats, | |
| ) | |
| class VotingService: | |
| """ | |
| Service for handling all database interactions related to voting and leaderboards. | |
| Encapsulates logic for submitting votes and retrieving formatted leaderboard statistics. | |
| """ | |
| def __init__(self, db_session_maker: AsyncDBSessionMaker): | |
| """ | |
| Initializes the VotingService. | |
| Args: | |
| db_session_maker: An asynchronous database session factory. | |
| """ | |
| self.db_session_maker: AsyncDBSessionMaker = db_session_maker | |
| logger.debug("VotingService initialized.") | |
| async def _create_db_session(self) -> AsyncSession | None: | |
| """ | |
| Creates a new database session, returning None if it's a dummy session. | |
| Returns: | |
| An active AsyncSession or None if using a dummy session factory. | |
| """ | |
| session = self.db_session_maker() | |
| # Check for a dummy session marker if your factory provides one | |
| is_dummy_session = getattr(session, "is_dummy", False) | |
| if is_dummy_session: | |
| logger.debug("Using dummy DB session; operations will be skipped.") | |
| # Ensure dummy sessions are also closed if they have resources | |
| if hasattr(session, "close"): | |
| await session.close() | |
| return None | |
| logger.debug("Created new DB session.") | |
| return session | |
| def _determine_comparison_type(self, provider_a: TTSProviderName, provider_b: TTSProviderName) -> ComparisonType: | |
| """ | |
| Determine the comparison type based on the given TTS provider names. | |
| Args: | |
| provider_a (TTSProviderName): The first TTS provider. | |
| provider_b (TTSProviderName): The second TTS provider. | |
| Returns: | |
| ComparisonType: The determined comparison type. | |
| Raises: | |
| ValueError: If the combination of providers is not recognized. | |
| """ | |
| if provider_a == constants.HUME_AI and provider_b == constants.HUME_AI: | |
| return constants.HUME_TO_HUME | |
| providers = (provider_a, provider_b) | |
| if constants.HUME_AI in providers and constants.ELEVENLABS in providers: | |
| return constants.HUME_TO_ELEVENLABS | |
| if constants.HUME_AI in providers and constants.OPENAI in providers: | |
| return constants.HUME_TO_OPENAI | |
| if constants.ELEVENLABS in providers and constants.OPENAI in providers: | |
| return constants.OPENAI_TO_ELEVENLABS | |
| raise ValueError(f"Invalid provider combination: {provider_a}, {provider_b}") | |
| async def _persist_vote(self, voting_results: VotingResults) -> None: | |
| """ | |
| Persists a vote record in the database using a dedicated session. | |
| Handles session creation, commit, rollback, and closure. Logs errors internally. | |
| Args: | |
| voting_results: A dictionary containing the vote details. | |
| """ | |
| session = await self._create_db_session() | |
| if session is None: | |
| logger.info("Skipping vote persistence (dummy session).") | |
| self._log_voting_results(voting_results) | |
| return | |
| try: | |
| self._log_voting_results(voting_results) | |
| await create_vote(session, voting_results) | |
| logger.info("Vote successfully persisted.") | |
| except Exception as e: | |
| logger.error(f"Failed to persist vote record: {e}", exc_info=True) | |
| finally: | |
| await session.close() | |
| logger.debug("DB session closed after persisting vote.") | |
| def _log_voting_results(self, voting_results: VotingResults) -> None: | |
| """Logs the full voting results dictionary.""" | |
| try: | |
| logger.info("Voting results:\n%s", json.dumps(voting_results, indent=4, default=str)) | |
| except TypeError: | |
| logger.error("Could not serialize voting results for logging.") | |
| logger.info(f"Voting results (raw): {voting_results}") | |
| def _format_leaderboard_data(self, leaderboard_data_raw: List[LeaderboardEntry]) -> List[List[str]]: | |
| """Formats raw leaderboard entries into HTML strings for the UI table.""" | |
| formatted_data = [] | |
| for rank, provider, model, win_rate, votes in leaderboard_data_raw: | |
| provider_info = constants.TTS_PROVIDER_LINKS.get(provider, {}) | |
| provider_link = provider_info.get("provider_link", "#") | |
| model_link = provider_info.get("model_link", "#") | |
| formatted_data.append([ | |
| f'<p style="text-align: center;">{rank}</p>', | |
| f'<a href="{provider_link}" target="_blank" class="provider-link">{provider}</a>', | |
| f'<a href="{model_link}" target="_blank" class="provider-link">{model}</a>', | |
| f'<p style="text-align: center;">{win_rate}</p>', | |
| f'<p style="text-align: center;">{votes}</p>', | |
| ]) | |
| return formatted_data | |
| def _format_battle_counts_data(self, battle_counts_data_raw: List[List[str]]) -> List[List[str]]: | |
| """Formats raw battle counts into an HTML matrix for the UI.""" | |
| battle_counts_dict = {item[0]: str(item[1]) for item in battle_counts_data_raw} | |
| providers = constants.TTS_PROVIDERS | |
| formatted_matrix: List[List[str]] = [] | |
| for row_provider in providers: | |
| row = [f'<p style="padding-left: 8px;"><strong>{row_provider}</strong></p>'] | |
| for col_provider in providers: | |
| if row_provider == col_provider: | |
| cell_value = "-" | |
| else: | |
| comparison_key = self._determine_comparison_type(row_provider, col_provider) | |
| cell_value = battle_counts_dict.get(comparison_key, "0") | |
| row.append(f'<p style="text-align: center;">{cell_value}</p>') | |
| formatted_matrix.append(row) | |
| return formatted_matrix | |
| def _format_win_rate_data(self, win_rate_data_raw: List[List[str]]) -> List[List[str]]: | |
| """Formats raw win rates into an HTML matrix for the UI.""" | |
| # win_rate_data_raw expected as [comparison_type, first_win_rate_str, second_win_rate_str] | |
| win_rates = {} | |
| for comparison_type, first_win_rate, second_win_rate in win_rate_data_raw: | |
| # Comparison type should already be canonical 'ProviderA - ProviderB' | |
| try: | |
| provider1, provider2 = comparison_type.split(" - ") | |
| win_rates[(provider1, provider2)] = first_win_rate | |
| win_rates[(provider2, provider1)] = second_win_rate | |
| except ValueError: | |
| logger.warning(f"Could not parse comparison_type '{comparison_type}' in win rate data.") | |
| continue # Skip malformed entry | |
| providers = constants.TTS_PROVIDERS | |
| formatted_matrix: List[List[str]] = [] | |
| for row_provider in providers: | |
| row = [f'<p style="padding-left: 8px;"><strong>{row_provider}</strong></p>'] | |
| for col_provider in providers: | |
| cell_value = "-" if row_provider == col_provider else win_rates.get((row_provider, col_provider), "0%") | |
| row.append(f'<p style="text-align: center;">{cell_value}</p>') | |
| formatted_matrix.append(row) | |
| return formatted_matrix | |
| async def get_formatted_leaderboard_data(self) -> Tuple[ | |
| List[List[str]], | |
| List[List[str]], | |
| List[List[str]], | |
| ]: | |
| """ | |
| Fetches raw leaderboard stats and formats them for UI display. | |
| Retrieves overall rankings, battle counts, and win rates, then formats | |
| them into HTML strings suitable for Gradio DataFrames. | |
| Returns: | |
| A tuple containing formatted lists of lists for: | |
| - Leaderboard rankings table | |
| - Battle counts matrix | |
| - Win rate matrix | |
| Returns empty lists ([[]], [[]], [[]]) on failure. | |
| """ | |
| session = await self._create_db_session() | |
| if session is None: | |
| logger.info("Skipping leaderboard fetch (dummy session).") | |
| return [[]], [[]], [[]] | |
| try: | |
| # Fetch raw data using underlying CRUD functions | |
| leaderboard_data_raw = await get_leaderboard_stats(session) | |
| battle_counts_data_raw = await get_head_to_head_battle_stats(session) | |
| win_rate_data_raw = await get_head_to_head_win_rate_stats(session) | |
| logger.debug("Fetched raw leaderboard data successfully.") | |
| # Format the data | |
| leaderboard_data = self._format_leaderboard_data(leaderboard_data_raw) | |
| battle_counts_data = self._format_battle_counts_data(battle_counts_data_raw) | |
| win_rate_data = self._format_win_rate_data(win_rate_data_raw) | |
| return leaderboard_data, battle_counts_data, win_rate_data | |
| except Exception as e: | |
| logger.error(f"Failed to fetch and format leaderboard data: {e}", exc_info=True) | |
| return [[]], [[]], [[]] # Return empty structure on error | |
| finally: | |
| await session.close() | |
| logger.debug("DB session closed after fetching leaderboard data.") | |
| async def submit_vote( | |
| self, | |
| option_map: OptionMap, | |
| selected_option: OptionKey, | |
| text_modified: bool, | |
| character_description: str, | |
| text: str, | |
| ) -> None: | |
| """ | |
| Constructs and persists a vote record based on user selection and context. | |
| This method is designed to be called safely from background tasks, handling all internal exceptions. | |
| Args: | |
| option_map: Mapping of comparison data and TTS options. | |
| selected_option: The option key ('option_a' or 'option_b') selected by the user. | |
| text_modified: Indicates if the text was custom vs. generated. | |
| character_description: Description used for TTS generation. | |
| text: The text synthesized. | |
| """ | |
| try: | |
| provider_a: TTSProviderName = option_map[constants.OPTION_A_KEY]["provider"] | |
| provider_b: TTSProviderName = option_map[constants.OPTION_B_KEY]["provider"] | |
| comparison_type: ComparisonType = self._determine_comparison_type(provider_a, provider_b) | |
| voting_results: VotingResults = { | |
| "comparison_type": comparison_type, | |
| "winning_provider": option_map[selected_option]["provider"], | |
| "winning_option": selected_option, | |
| "option_a_provider": provider_a, | |
| "option_b_provider": provider_b, | |
| "option_a_generation_id": option_map[constants.OPTION_A_KEY]["generation_id"], | |
| "option_b_generation_id": option_map[constants.OPTION_B_KEY]["generation_id"], | |
| "character_description": character_description, | |
| "text": text, | |
| "is_custom_text": text_modified, | |
| } | |
| await self._persist_vote(voting_results) | |
| except KeyError as e: | |
| logger.error( | |
| f"Missing key in option_map during vote submission: {e}. OptionMap: {option_map}", | |
| exc_info=True | |
| ) | |
| except Exception as e: | |
| logger.error(f"Unexpected error in submit_vote: {e}", exc_info=True) | |