# labor-archive-backend / data_loader.py
from datasets import load_dataset
from typing import List, Optional, Dict, Any
from models import ArticleResponse, ArticleDetail, Argument, FiltersResponse
from collections import Counter
from functools import lru_cache
from whoosh.fields import Schema, TEXT, ID
from whoosh.qparser import QueryParser
from whoosh.filedb.filestore import RamStorage
from dateutil import parser as date_parser
import numpy as np
from sentence_transformers import SentenceTransformer
# Constants
SEARCH_CACHE_MAX_SIZE = 1000
LABOR_SCORE_WEIGHT = 0.1  # Weight for labor score in relevance calculation
DENSE_SIMILARITY_THRESHOLD = 0.8  # Minimum cosine similarity for dense search hits
DATE_RANGE_START = "2022-01-01"
DATE_RANGE_END = "2025-12-31"
class DataLoader:
"""
Handles loading, indexing, and searching of AI labor economy articles.
Uses Whoosh for full-text search and maintains in-memory data structures
for fast filtering and pagination.
"""
def __init__(self):
self.dataset = None
self.articles = []
self.articles_by_id = {} # ID -> article mapping
self.filter_options = None
# Initialize Whoosh search index for full-text search
self.search_schema = Schema(
id=ID(stored=True),
title=TEXT(stored=False),
summary=TEXT(stored=False),
content=TEXT(stored=False) # Combined title + summary for search
)
# Create in-memory index using RamStorage
storage = RamStorage()
self.search_index = storage.create_index(self.search_schema)
self.query_parser = QueryParser("content", self.search_schema)
# Dense retrieval components (lazy-loaded for efficiency)
self.embeddings = None # Article embeddings from dataset
self.embedding_model = None # SentenceTransformer model
self.model_path = "ibm-granite/granite-embedding-english-r2"
# Note: Using lru_cache for search caching instead of manual cache management
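    # Minimal usage sketch (hypothetical caller; the server wiring is not part
    # of this file, so the async startup shown here is an assumption):
    #
    #     loader = DataLoader()
    #     await loader.load_dataset()   # builds the index and loads embeddings
    #     articles = loader.get_articles(page=1, limit=20,
    #                                    search_query="automation",
    #                                    search_type="dense", sort_by="score")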
async def load_dataset(self):
"""Load and process the HuggingFace dataset"""
# Load dataset
self.dataset = load_dataset("yjernite/ai-economy-labor-articles-annotated-embed", split="train")
# Convert to list of dicts for easier processing
self.articles = []
# Prepare Whoosh index writer
writer = self.search_index.writer()
for i, row in enumerate(self.dataset):
# Parse date using dateutil (more flexible than pandas)
date = date_parser.parse(row['date']) if isinstance(row['date'], str) else row['date']
# Parse arguments
arguments = []
if row.get('arguments'):
for arg in row['arguments']:
arguments.append(Argument(
argument_quote=arg.get('argument_quote', []),
argument_summary=arg.get('argument_summary', ''),
argument_source=arg.get('argument_source', ''),
argument_type=arg.get('argument_type', ''),
))
article = {
'id': i,
'title': row.get('title', ''),
'source': row.get('source', ''),
'url': row.get('url', ''),
'date': date,
'summary': row.get('summary', ''),
'ai_labor_relevance': row.get('ai_labor_relevance', 0),
'document_type': row.get('document_type', ''),
'author_type': row.get('author_type', ''),
'document_topics': row.get('document_topics', []),
'text': row.get('text', ''),
'arguments': arguments,
}
self.articles.append(article)
self.articles_by_id[i] = article
# Add to search index
search_content = f"{article['title']} {article['summary']}"
writer.add_document(
id=str(i),
title=article['title'],
summary=article['summary'],
content=search_content
)
# Commit search index
writer.commit()
print(f"DEBUG: Search index populated with {len(self.articles)} articles")
# Load pre-computed embeddings for dense retrieval
print("DEBUG: Loading pre-computed embeddings...")
raw_embeddings = np.array(self.dataset['embeddings-granite'])
# Normalize embeddings for cosine similarity
self.embeddings = raw_embeddings / np.linalg.norm(raw_embeddings, axis=1, keepdims=True)
print(f"DEBUG: Loaded {len(self.embeddings)} article embeddings")
# Build filter options
self._build_filter_options()
def _build_filter_options(self):
"""Build available filter options from the dataset"""
document_types = sorted(set(article['document_type'] for article in self.articles if article['document_type']))
author_types = sorted(set(article['author_type'] for article in self.articles if article['author_type']))
# Flatten all topics
all_topics = []
for article in self.articles:
if article['document_topics']:
all_topics.extend(article['document_topics'])
topics = sorted(set(all_topics))
# Date range - fixed for research period
min_date = DATE_RANGE_START
max_date = DATE_RANGE_END
# Relevance range
relevances = [article['ai_labor_relevance'] for article in self.articles if article['ai_labor_relevance'] is not None]
min_relevance = min(relevances) if relevances else 0
max_relevance = max(relevances) if relevances else 10
self.filter_options = FiltersResponse(
document_types=document_types,
author_types=author_types,
topics=topics,
date_range={"min_date": min_date, "max_date": max_date},
relevance_range={"min_relevance": min_relevance, "max_relevance": max_relevance}
)
def get_filter_options(self) -> FiltersResponse:
"""Get all available filter options"""
return self.filter_options
def _filter_articles(
self,
document_types: Optional[List[str]] = None,
author_types: Optional[List[str]] = None,
min_relevance: Optional[float] = None,
max_relevance: Optional[float] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
topics: Optional[List[str]] = None,
search_query: Optional[str] = None,
search_type: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Filter articles based on criteria"""
filtered = self.articles
if document_types:
filtered = [a for a in filtered if a['document_type'] in document_types]
if author_types:
filtered = [a for a in filtered if a['author_type'] in author_types]
        if min_relevance is not None:
            filtered = [a for a in filtered if a['ai_labor_relevance'] is not None and a['ai_labor_relevance'] >= min_relevance]
        if max_relevance is not None:
            filtered = [a for a in filtered if a['ai_labor_relevance'] is not None and a['ai_labor_relevance'] <= max_relevance]
if start_date:
start_dt = date_parser.parse(start_date)
filtered = [a for a in filtered if a['date'] >= start_dt]
if end_date:
end_dt = date_parser.parse(end_date)
filtered = [a for a in filtered if a['date'] <= end_dt]
if topics:
filtered = [a for a in filtered if any(topic in a['document_topics'] for topic in topics)]
if search_query:
print(f"DEBUG: Applying search filter for query: '{search_query}' with type: '{search_type}'")
if search_type == 'dense':
# For dense search, get similarity scores for all articles
dense_scores = self._dense_search_all_articles(search_query)
dense_score_dict = {idx: score for idx, score in dense_scores}
# Attach dense scores to filtered articles and filter by similarity threshold
filtered_with_scores = []
for article in filtered:
article_id = article['id']
if article_id in dense_score_dict:
# Create a copy to avoid modifying the original
article_copy = article.copy()
article_copy['dense_score'] = dense_score_dict[article_id]
                        # Only include articles with meaningful similarity
                        if dense_score_dict[article_id] > DENSE_SIMILARITY_THRESHOLD:
filtered_with_scores.append(article_copy)
filtered = filtered_with_scores
print(f"DEBUG: After dense search filtering: {len(filtered)} articles remaining")
else:
# Existing exact search logic - inline the matching check
search_results = self._search_articles(search_query, search_type)
filtered = [a for a in filtered if a.get('id') in search_results]
print(f"DEBUG: After exact search filtering: {len(filtered)} articles remaining")
return filtered
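    # Illustration of the dense branch above (made-up scores): if
    # _dense_search_all_articles returns [(3, 0.91), (7, 0.84), (0, 0.42), ...],
    # only ids 3 and 7 clear DENSE_SIMILARITY_THRESHOLD, so 'filtered' ends up
    # holding copies of those two articles with 'dense_score' attached.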
def _search_articles(self, search_query: str, search_type: Optional[str] = None) -> Dict[int, float]:
"""Search articles using Whoosh and return article_id -> score mapping
Note: Dense search is handled separately in _filter_articles method.
This method only handles exact/Whoosh search.
"""
if not search_query:
return {}
# Use cached Whoosh search (lru_cache handles caching automatically)
return self._cached_whoosh_search(search_query)
@lru_cache(maxsize=SEARCH_CACHE_MAX_SIZE)
def _cached_whoosh_search(self, search_query: str) -> Dict[int, float]:
"""Cached version of Whoosh search using lru_cache"""
return self._whoosh_search(search_query)
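    # Note: lru_cache on an instance method keys on (self, search_query) and
    # keeps self alive for the cache's lifetime; acceptable here assuming a
    # single long-lived DataLoader per process.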
def _whoosh_search(self, search_query: str) -> Dict[int, float]:
"""Perform search using Whoosh index"""
try:
with self.search_index.searcher() as searcher:
# Parse query - Whoosh handles tokenization automatically
query = self.query_parser.parse(search_query)
results = searcher.search(query, limit=None) # Get all results
print(f"DEBUG: Search query '{search_query}' found {len(results)} results")
# Return mapping of article_id -> normalized score
article_scores = {}
max_score = max((r.score for r in results), default=1.0)
for result in results:
article_id = int(result['id'])
# Normalize score to 0-1 range
normalized_score = result.score / max_score if max_score > 0 else 0.0
article_scores[article_id] = normalized_score
print(f"DEBUG: Returning {len(article_scores)} scored articles")
return article_scores
except Exception as e:
print(f"Search error: {e}")
return {}
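    # Worked example of the normalization above (illustrative raw scores):
    # BM25F scores [4.2, 2.1, 1.05] have max 4.2 and normalize to
    # [1.0, 0.5, 0.25], making scores comparable across queries and with the
    # 0-1 range used by dense similarity.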
def _initialize_embedding_model(self):
"""Lazy initialization of embedding model (CPU-only)"""
if self.embedding_model is None:
print("DEBUG: Initializing embedding model (CPU-only)...")
# Force CPU usage and disable problematic features
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
# Initialize model with CPU device and specific config
self.embedding_model = SentenceTransformer(
self.model_path,
device='cpu',
                model_kwargs={
                    'dtype': 'float32',  # 'dtype' replaces the deprecated 'torch_dtype' kwarg
                    'attn_implementation': 'eager'  # Use eager attention instead of flash attention
                }
)
print("DEBUG: Embedding model initialized")
@lru_cache(maxsize=100) # Cache encoded queries (smaller cache for this)
def _encode_query_cached(self, query: str) -> tuple:
"""Cache-friendly version of query encoding (returns tuple for hashing)"""
embedding = self._encode_query_internal(query)
return tuple(embedding.tolist()) # Convert to tuple for caching
def _encode_query(self, query: str) -> np.ndarray:
"""Encode a query string into an embedding vector"""
cached_result = self._encode_query_cached(query)
return np.array(cached_result) # Convert back to numpy array
def _encode_query_internal(self, query: str) -> np.ndarray:
"""Internal method that does the actual encoding"""
self._initialize_embedding_model()
query_embedding = self.embedding_model.encode([query])
# Normalize for cosine similarity
return query_embedding[0] / np.linalg.norm(query_embedding[0])
    def _dense_search_all_articles(self, query: str, k: Optional[int] = None) -> List[tuple]:
"""
Perform dense retrieval across ALL articles and return (index, score) pairs.
This computes all similarities upfront for maximum flexibility.
"""
if self.embeddings is None:
print("ERROR: Embeddings not loaded")
return []
print(f"DEBUG: Performing dense search for query: '{query}'")
# Encode query
query_embed = self._encode_query(query)
# Compute similarities with ALL articles
similarities = np.dot(self.embeddings, query_embed)
# Get all articles with their similarity scores
        article_scores = [(i, float(score)) for i, score in enumerate(similarities)]
# Sort by similarity (highest first)
article_scores.sort(key=lambda x: x[1], reverse=True)
# Apply k limit if specified
if k is not None:
article_scores = article_scores[:k]
print(f"DEBUG: Dense search found {len(article_scores)} scored articles")
return article_scores
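    # Since both the stored article embeddings (normalized in load_dataset) and
    # the query vector (normalized in _encode_query_internal) are unit-length,
    # the np.dot above computes cosine similarity directly. Equivalent sketch:
    #
    #     a = v / np.linalg.norm(v)
    #     b = w / np.linalg.norm(w)
    #     np.dot(a, b)   # cosine similarity of v and w, in [-1, 1]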
def _calculate_query_score(self, article: Dict[str, Any], search_query: str, search_type: Optional[str] = None) -> float:
"""Calculate query relevance score based on search type"""
if not search_query:
return 0.0
if search_type == 'dense':
# For dense search, return the pre-computed similarity score
return article.get('dense_score', 0.0)
else:
# Existing exact search logic using Whoosh
search_results = self._search_articles(search_query, search_type)
article_id = article.get('id')
# Return Whoosh score or 0.0 if not found
return search_results.get(article_id, 0.0)
def _sort_by_relevance(self, articles: List[Dict[str, Any]], search_query: str, search_type: str = 'exact') -> List[Dict[str, Any]]:
"""Sort articles by relevance score (query score + labor score)"""
def relevance_key(article):
query_score = self._calculate_query_score(article, search_query, search_type)
            labor_score = (article.get('ai_labor_relevance') or 0) / 10.0  # Normalize to 0-1; treat missing as 0
# Prioritize query score but include labor score as tiebreaker
return query_score + (labor_score * LABOR_SCORE_WEIGHT)
return sorted(articles, key=relevance_key, reverse=True)
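    # Worked example (illustrative numbers): query_score 0.9 with
    # ai_labor_relevance 7 yields 0.9 + (0.7 * 0.1) = 0.97, so the labor score
    # breaks ties without outweighing a clearly better query match.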
def get_articles(
self,
page: int = 1,
limit: int = 20,
**filters
) -> List[ArticleResponse]:
"""Get filtered and paginated articles"""
# Extract sort_by, search_query, and search_type for special handling
sort_by = filters.pop('sort_by', 'date')
search_query = filters.get('search_query')
search_type = filters.get('search_type', 'exact')
filtered_articles = self._filter_articles(**filters)
# Apply sorting
if sort_by == 'score' and search_query:
# Sort by query relevance score descending, then by labor score
filtered_articles = self._sort_by_relevance(filtered_articles, search_query, search_type)
else:
# Sort by date (oldest first) - default
filtered_articles.sort(key=lambda x: x['date'], reverse=False)
# Paginate
start_idx = (page - 1) * limit
end_idx = start_idx + limit
page_articles = filtered_articles[start_idx:end_idx]
# Convert to response models - use the original ID from the sorted/filtered results
return [
ArticleResponse(
id=article['id'],
title=article['title'],
source=article['source'],
url=article['url'],
date=article['date'],
summary=article['summary'],
ai_labor_relevance=article['ai_labor_relevance'],
query_score=self._calculate_query_score(article, search_query or '', search_type),
document_type=article['document_type'],
author_type=article['author_type'],
document_topics=article['document_topics'],
)
for article in page_articles
]
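    # Pagination example: page=3 with limit=20 selects indices 40-59, since
    # start_idx = (3 - 1) * 20 = 40 and end_idx = 60.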
def get_articles_count(self, **filters) -> int:
"""Get count of articles matching filters"""
filtered_articles = self._filter_articles(**filters)
return len(filtered_articles)
def get_filter_counts(self, filter_type: str, **filters) -> Dict[str, int]:
"""Get counts for each option in a specific filter type, given other filters"""
# Remove the current filter type from filters to avoid circular filtering
filters_copy = filters.copy()
filters_copy.pop(filter_type, None)
# Get base filtered articles (without the current filter type)
base_filtered = self._filter_articles(**filters_copy)
# Extract values based on filter type and count with Counter
if filter_type == 'document_types':
values = [article.get('document_type') for article in base_filtered
if article.get('document_type')]
elif filter_type == 'author_types':
values = [article.get('author_type') for article in base_filtered
if article.get('author_type')]
elif filter_type == 'topics':
values = [topic for article in base_filtered
for topic in article.get('document_topics', []) if topic]
else:
return {}
return dict(Counter(values))
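    # Example return value (hypothetical counts):
    #   get_filter_counts('document_types', author_types=['journalist'])
    #   -> {'news': 42, 'opinion': 17, 'report': 5}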
def get_article_detail(self, article_id: int) -> ArticleDetail:
"""Get detailed article by ID"""
if article_id not in self.articles_by_id:
raise ValueError(f"Article ID {article_id} not found")
article = self.articles_by_id[article_id]
return ArticleDetail(
id=article['id'],
title=article['title'],
source=article['source'],
url=article['url'],
date=article['date'],
summary=article['summary'],
ai_labor_relevance=article['ai_labor_relevance'],
query_score=0.0, # Detail view doesn't have search context
document_type=article['document_type'],
author_type=article['author_type'],
document_topics=article['document_topics'],
text=article['text'],
arguments=article['arguments'],
)