Spaces:
Running
Running
| from datasets import load_dataset | |
| from typing import List, Optional, Dict, Any | |
| from datetime import datetime | |
| from models import ArticleResponse, ArticleDetail, Argument, FiltersResponse | |
| from collections import Counter | |
| from functools import lru_cache | |
| from whoosh import index | |
| from whoosh.fields import Schema, TEXT, ID | |
| from whoosh.qparser import QueryParser | |
| from whoosh.filedb.filestore import RamStorage | |
| from dateutil import parser as date_parser | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
# Constants
SEARCH_CACHE_MAX_SIZE = 1000  # Max entries for the per-instance search result cache
LABOR_SCORE_WEIGHT = 0.1  # Weight for labor score in relevance calculation
DATE_RANGE_START = "2022-01-01"  # Fixed research-period lower bound exposed in filter options
DATE_RANGE_END = "2025-12-31"  # Fixed research-period upper bound exposed in filter options
class DataLoader:
    """
    Handles loading, indexing, and searching of AI labor economy articles.

    Uses Whoosh for full-text (exact) search, pre-computed sentence embeddings
    for dense retrieval, and maintains in-memory data structures for fast
    filtering and pagination.
    """

    # Minimum cosine similarity for a dense-search hit to be kept.
    DENSE_SIMILARITY_THRESHOLD = 0.8

    def __init__(self):
        self.dataset = None
        self.articles = []  # canonical list of article dicts, in dataset order
        self.articles_by_id = {}  # ID -> article mapping
        self.filter_options = None
        # Initialize Whoosh search index for full-text search
        self.search_schema = Schema(
            id=ID(stored=True),
            title=TEXT(stored=False),
            summary=TEXT(stored=False),
            content=TEXT(stored=False)  # Combined title + summary for search
        )
        # Create in-memory index using RamStorage
        storage = RamStorage()
        self.search_index = storage.create_index(self.search_schema)
        self.query_parser = QueryParser("content", self.search_schema)
        # Dense retrieval components (lazy-loaded for efficiency)
        self.embeddings = None  # normalized article embeddings from dataset
        self.embedding_model = None  # SentenceTransformer model (lazy-loaded)
        self.model_path = "ibm-granite/granite-embedding-english-r2"
        # BUGFIX: the original methods were *named* "cached" but carried no
        # @lru_cache, so every call re-ran the Whoosh search / model encode.
        # Wrap the bound methods in per-instance LRU caches here instead of
        # decorating them at class level, which would keep every instance
        # alive for the lifetime of a shared cache (ruff B019).
        self._cached_whoosh_search = lru_cache(maxsize=SEARCH_CACHE_MAX_SIZE)(self._whoosh_search)
        self._encode_query_cached = lru_cache(maxsize=SEARCH_CACHE_MAX_SIZE)(self._encode_query_uncached)

    async def load_dataset(self):
        """Load the HuggingFace dataset, build the search index, embeddings, and filter options."""
        # Load dataset
        self.dataset = load_dataset("yjernite/ai-economy-labor-articles-annotated-embed", split="train")
        # Convert to list of dicts for easier processing
        self.articles = []
        # Prepare Whoosh index writer
        writer = self.search_index.writer()
        for i, row in enumerate(self.dataset):
            # Parse date using dateutil (more flexible than pandas)
            date = date_parser.parse(row['date']) if isinstance(row['date'], str) else row['date']
            # Parse arguments
            arguments = []
            if row.get('arguments'):
                for arg in row['arguments']:
                    arguments.append(Argument(
                        argument_quote=arg.get('argument_quote', []),
                        argument_summary=arg.get('argument_summary', ''),
                        argument_source=arg.get('argument_source', ''),
                        argument_type=arg.get('argument_type', ''),
                    ))
            article = {
                'id': i,
                'title': row.get('title', ''),
                'source': row.get('source', ''),
                'url': row.get('url', ''),
                'date': date,
                'summary': row.get('summary', ''),
                'ai_labor_relevance': row.get('ai_labor_relevance', 0),
                'document_type': row.get('document_type', ''),
                'author_type': row.get('author_type', ''),
                'document_topics': row.get('document_topics', []),
                'text': row.get('text', ''),
                'arguments': arguments,
            }
            self.articles.append(article)
            self.articles_by_id[i] = article
            # Add to search index
            search_content = f"{article['title']} {article['summary']}"
            writer.add_document(
                id=str(i),
                title=article['title'],
                summary=article['summary'],
                content=search_content
            )
        # Commit search index
        writer.commit()
        print(f"DEBUG: Search index populated with {len(self.articles)} articles")
        # Load pre-computed embeddings for dense retrieval
        print("DEBUG: Loading pre-computed embeddings...")
        raw_embeddings = np.array(self.dataset['embeddings-granite'])
        # Normalize embeddings for cosine similarity.
        # Guard against zero-norm rows, which would otherwise yield NaN scores.
        norms = np.linalg.norm(raw_embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        self.embeddings = raw_embeddings / norms
        print(f"DEBUG: Loaded {len(self.embeddings)} article embeddings")
        # Build filter options
        self._build_filter_options()

    def _build_filter_options(self):
        """Build available filter options from the dataset."""
        document_types = sorted(set(article['document_type'] for article in self.articles if article['document_type']))
        author_types = sorted(set(article['author_type'] for article in self.articles if article['author_type']))
        # Flatten all topics
        all_topics = []
        for article in self.articles:
            if article['document_topics']:
                all_topics.extend(article['document_topics'])
        topics = sorted(set(all_topics))
        # Date range - fixed for research period
        min_date = DATE_RANGE_START
        max_date = DATE_RANGE_END
        # Relevance range (ignore articles with missing relevance)
        relevances = [article['ai_labor_relevance'] for article in self.articles if article['ai_labor_relevance'] is not None]
        min_relevance = min(relevances) if relevances else 0
        max_relevance = max(relevances) if relevances else 10
        self.filter_options = FiltersResponse(
            document_types=document_types,
            author_types=author_types,
            topics=topics,
            date_range={"min_date": min_date, "max_date": max_date},
            relevance_range={"min_relevance": min_relevance, "max_relevance": max_relevance}
        )

    def get_filter_options(self) -> FiltersResponse:
        """Get all available filter options."""
        return self.filter_options

    def _filter_articles(
        self,
        document_types: Optional[List[str]] = None,
        author_types: Optional[List[str]] = None,
        min_relevance: Optional[float] = None,
        max_relevance: Optional[float] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        topics: Optional[List[str]] = None,
        search_query: Optional[str] = None,
        search_type: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Filter articles based on criteria and return a NEW list.

        BUGFIX: previously this returned ``self.articles`` itself when no
        filters were supplied, so callers that sorted the result in place
        silently reordered the loader's canonical article list.
        """
        filtered = list(self.articles)
        if document_types:
            filtered = [a for a in filtered if a['document_type'] in document_types]
        if author_types:
            filtered = [a for a in filtered if a['author_type'] in author_types]
        if min_relevance is not None:
            filtered = [a for a in filtered if a['ai_labor_relevance'] >= min_relevance]
        if max_relevance is not None:
            filtered = [a for a in filtered if a['ai_labor_relevance'] <= max_relevance]
        if start_date:
            start_dt = date_parser.parse(start_date)
            filtered = [a for a in filtered if a['date'] >= start_dt]
        if end_date:
            end_dt = date_parser.parse(end_date)
            filtered = [a for a in filtered if a['date'] <= end_dt]
        if topics:
            filtered = [a for a in filtered if any(topic in a['document_topics'] for topic in topics)]
        if search_query:
            print(f"DEBUG: Applying search filter for query: '{search_query}' with type: '{search_type}'")
            if search_type == 'dense':
                # For dense search, get similarity scores for all articles
                dense_scores = self._dense_search_all_articles(search_query)
                dense_score_dict = {idx: score for idx, score in dense_scores}
                # Attach dense scores to filtered articles and filter by similarity threshold
                filtered_with_scores = []
                for article in filtered:
                    article_id = article['id']
                    if article_id in dense_score_dict:
                        # Create a copy to avoid modifying the original
                        article_copy = article.copy()
                        article_copy['dense_score'] = dense_score_dict[article_id]
                        # Only include articles with meaningful similarity
                        if dense_score_dict[article_id] > self.DENSE_SIMILARITY_THRESHOLD:
                            filtered_with_scores.append(article_copy)
                filtered = filtered_with_scores
                print(f"DEBUG: After dense search filtering: {len(filtered)} articles remaining")
            else:
                # Exact (Whoosh) search - keep only articles with a hit
                search_results = self._search_articles(search_query, search_type)
                filtered = [a for a in filtered if a.get('id') in search_results]
                print(f"DEBUG: After exact search filtering: {len(filtered)} articles remaining")
        return filtered

    def _search_articles(self, search_query: str, search_type: Optional[str] = None) -> Dict[int, float]:
        """Search articles using Whoosh and return article_id -> score mapping.

        Note: Dense search is handled separately in _filter_articles method.
        This method only handles exact/Whoosh search.
        """
        if not search_query:
            return {}
        # self._cached_whoosh_search is an lru_cache-wrapped bound method
        # created in __init__, so repeated queries hit the cache.
        return self._cached_whoosh_search(search_query)

    def _whoosh_search(self, search_query: str) -> Dict[int, float]:
        """Perform search using the Whoosh index (uncached; wrapped in __init__)."""
        try:
            with self.search_index.searcher() as searcher:
                # Parse query - Whoosh handles tokenization automatically
                query = self.query_parser.parse(search_query)
                results = searcher.search(query, limit=None)  # Get all results
                print(f"DEBUG: Search query '{search_query}' found {len(results)} results")
                # Return mapping of article_id -> normalized score
                article_scores = {}
                max_score = max((r.score for r in results), default=1.0)
                for result in results:
                    article_id = int(result['id'])
                    # Normalize score to 0-1 range
                    normalized_score = result.score / max_score if max_score > 0 else 0.0
                    article_scores[article_id] = normalized_score
                print(f"DEBUG: Returning {len(article_scores)} scored articles")
                return article_scores
        except Exception as e:
            print(f"Search error: {e}")
            return {}

    def _initialize_embedding_model(self):
        """Lazy initialization of the embedding model (CPU-only)."""
        if self.embedding_model is None:
            print("DEBUG: Initializing embedding model (CPU-only)...")
            # Force CPU usage and disable problematic features
            import os
            os.environ['CUDA_VISIBLE_DEVICES'] = ''
            # Initialize model with CPU device and specific config
            self.embedding_model = SentenceTransformer(
                self.model_path,
                device='cpu',
                model_kwargs={
                    'dtype': 'float32',  # Fixed deprecation warning
                    'attn_implementation': 'eager'  # Use eager attention instead of flash attention
                }
            )
            print("DEBUG: Embedding model initialized")

    def _encode_query_uncached(self, query: str) -> tuple:
        """Encode *query* and return a tuple (hashable, so lru_cache can store it)."""
        embedding = self._encode_query_internal(query)
        return tuple(embedding.tolist())

    def _encode_query(self, query: str) -> np.ndarray:
        """Encode a query string into a normalized embedding vector (cached)."""
        cached_result = self._encode_query_cached(query)
        return np.array(cached_result)  # Convert back to numpy array

    def _encode_query_internal(self, query: str) -> np.ndarray:
        """Internal method that does the actual encoding."""
        self._initialize_embedding_model()
        query_embedding = self.embedding_model.encode([query])
        # Normalize for cosine similarity
        return query_embedding[0] / np.linalg.norm(query_embedding[0])

    def _dense_search_all_articles(self, query: str, k: Optional[int] = None) -> List[tuple]:
        """
        Perform dense retrieval across ALL articles and return (index, score) pairs.

        Computes all similarities upfront for maximum flexibility; optionally
        truncates to the top *k* results.
        """
        if self.embeddings is None:
            print("ERROR: Embeddings not loaded")
            return []
        print(f"DEBUG: Performing dense search for query: '{query}'")
        # Encode query
        query_embed = self._encode_query(query)
        # Compute similarities with ALL articles (embeddings are pre-normalized,
        # so the dot product is cosine similarity)
        similarities = np.dot(self.embeddings, query_embed)
        # Pair every article index with its similarity score
        article_scores = [(i, float(score)) for i, score in enumerate(similarities)]
        # Sort by similarity (highest first)
        article_scores.sort(key=lambda x: x[1], reverse=True)
        # Apply k limit if specified
        if k is not None:
            article_scores = article_scores[:k]
        print(f"DEBUG: Dense search found {len(article_scores)} scored articles")
        return article_scores

    def _calculate_query_score(self, article: Dict[str, Any], search_query: str, search_type: Optional[str] = None) -> float:
        """Calculate query relevance score for one article based on search type."""
        if not search_query:
            return 0.0
        if search_type == 'dense':
            # For dense search, return the pre-computed similarity score
            return article.get('dense_score', 0.0)
        else:
            # Exact search: cached Whoosh lookup, so per-article calls are cheap
            search_results = self._search_articles(search_query, search_type)
            article_id = article.get('id')
            # Return Whoosh score or 0.0 if not found
            return search_results.get(article_id, 0.0)

    def _sort_by_relevance(self, articles: List[Dict[str, Any]], search_query: str, search_type: str = 'exact') -> List[Dict[str, Any]]:
        """Sort articles by relevance score (query score + weighted labor score)."""
        def relevance_key(article):
            query_score = self._calculate_query_score(article, search_query, search_type)
            labor_score = article.get('ai_labor_relevance', 0) / 10.0  # Normalize to 0-1
            # Prioritize query score but include labor score as tiebreaker
            return query_score + (labor_score * LABOR_SCORE_WEIGHT)
        return sorted(articles, key=relevance_key, reverse=True)

    def get_articles(
        self,
        page: int = 1,
        limit: int = 20,
        **filters
    ) -> List[ArticleResponse]:
        """Get filtered, sorted, and paginated articles.

        Pages are 1-indexed; values below 1 are clamped to the first page
        (previously a negative slice start returned wrong results).
        """
        # Extract sort_by, search_query, and search_type for special handling
        sort_by = filters.pop('sort_by', 'date')
        search_query = filters.get('search_query')
        search_type = filters.get('search_type', 'exact')
        filtered_articles = self._filter_articles(**filters)
        # Apply sorting
        if sort_by == 'score' and search_query:
            # Sort by query relevance score descending, then by labor score
            filtered_articles = self._sort_by_relevance(filtered_articles, search_query, search_type)
        else:
            # Sort by date (oldest first) - default. sorted() avoids mutating
            # any shared list in place.
            filtered_articles = sorted(filtered_articles, key=lambda x: x['date'])
        # Paginate (clamp page to >= 1)
        start_idx = max(page - 1, 0) * limit
        end_idx = start_idx + limit
        page_articles = filtered_articles[start_idx:end_idx]
        # Convert to response models - use the original ID from the sorted/filtered results
        return [
            ArticleResponse(
                id=article['id'],
                title=article['title'],
                source=article['source'],
                url=article['url'],
                date=article['date'],
                summary=article['summary'],
                ai_labor_relevance=article['ai_labor_relevance'],
                query_score=self._calculate_query_score(article, search_query or '', search_type),
                document_type=article['document_type'],
                author_type=article['author_type'],
                document_topics=article['document_topics'],
            )
            for article in page_articles
        ]

    def get_articles_count(self, **filters) -> int:
        """Get count of articles matching filters."""
        filtered_articles = self._filter_articles(**filters)
        return len(filtered_articles)

    def get_filter_counts(self, filter_type: str, **filters) -> Dict[str, int]:
        """Get counts for each option in a specific filter type, given other filters."""
        # Remove the current filter type from filters to avoid circular filtering
        filters_copy = filters.copy()
        filters_copy.pop(filter_type, None)
        # Get base filtered articles (without the current filter type)
        base_filtered = self._filter_articles(**filters_copy)
        # Extract values based on filter type and count with Counter
        if filter_type == 'document_types':
            values = [article.get('document_type') for article in base_filtered
                      if article.get('document_type')]
        elif filter_type == 'author_types':
            values = [article.get('author_type') for article in base_filtered
                      if article.get('author_type')]
        elif filter_type == 'topics':
            values = [topic for article in base_filtered
                      for topic in article.get('document_topics', []) if topic]
        else:
            return {}
        return dict(Counter(values))

    def get_article_detail(self, article_id: int) -> ArticleDetail:
        """Get detailed article by ID.

        Raises:
            ValueError: if *article_id* is unknown.
        """
        if article_id not in self.articles_by_id:
            raise ValueError(f"Article ID {article_id} not found")
        article = self.articles_by_id[article_id]
        return ArticleDetail(
            id=article['id'],
            title=article['title'],
            source=article['source'],
            url=article['url'],
            date=article['date'],
            summary=article['summary'],
            ai_labor_relevance=article['ai_labor_relevance'],
            query_score=0.0,  # Detail view doesn't have search context
            document_type=article['document_type'],
            author_type=article['author_type'],
            document_topics=article['document_topics'],
            text=article['text'],
            arguments=article['arguments'],
        )