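"""Headline analysis: scores how well a headline matches its article body.

Supports an AI mode (NLI and zero-shot classification via transformers) and a
traditional mode (word-overlap and clickbait-pattern heuristics).
"""
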
import logging
from typing import Dict, Any, List, Optional

from transformers import pipeline, AutoTokenizer
import numpy as np
import nltk
from nltk.tokenize import sent_tokenize

logger = logging.getLogger(__name__)


class HeadlineAnalyzer:
    def __init__(self, use_ai: bool = True, model_registry: Optional[Any] = None):
        """
        Initialize the analyzers for headline analysis.

        Args:
            use_ai: Boolean indicating whether to use AI-powered analysis (True) or traditional analysis (False)
            model_registry: Optional shared model registry for better performance
        """
        self.use_ai = use_ai
        self.llm_available = False
        self.model_registry = model_registry

        if use_ai:
            try:
                if model_registry and model_registry.is_available:
                    # Use shared models
                    self.nli_pipeline = model_registry.nli
                    self.zero_shot = model_registry.zero_shot
                    self.tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
                    self.llm_available = True
                    logger.info("Using shared model pipelines for headline analysis")
                else:
                    # Initialize own pipelines
                    self.nli_pipeline = pipeline(
                        "text-classification",
                        model="roberta-large-mnli",
                        batch_size=16
                    )
                    self.zero_shot = pipeline(
                        "zero-shot-classification",
                        model="facebook/bart-large-mnli",
                        device=-1,
                        batch_size=8
                    )
                    self.tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
                    self.llm_available = True
                    logger.info("Initialized dedicated model pipelines for headline analysis")
                self.max_length = 512
            except Exception as e:
                logger.warning(f"Failed to initialize LLM pipelines: {str(e)}")
                self.llm_available = False
        else:
            logger.info("Initializing headline analyzer in traditional mode")

    def _split_content(self, headline: str, content: str) -> List[str]:
        """Split content into sections that fit within the model's token limit."""
        # Account for the headline and [SEP] token in the max length
        headline_tokens = len(self.tokenizer.encode(headline))
        sep_tokens = len(self.tokenizer.encode("[SEP]", add_special_tokens=False))
        max_content_tokens = self.max_length - headline_tokens - sep_tokens

        # Greedily pack words into sections that stay within the token budget,
        # so each "headline [SEP] section" pair fits the 512-token model limit
        sections = []
        current_words: List[str] = []
        current_tokens = 0
        for word in content.split():
            word_tokens = len(self.tokenizer.encode(word, add_special_tokens=False))
            if current_tokens + word_tokens > max_content_tokens and current_words:
                sections.append(" ".join(current_words))
                current_words = []
                current_tokens = 0
            current_words.append(word)
            current_tokens += word_tokens
        if current_words:
            sections.append(" ".join(current_words))
        return sections

    def _analyze_section(self, headline: str, section: str) -> Dict[str, Any]:
        """Analyze a single section for headline accuracy and sensationalism."""
        try:
            logger.info("\n" + "-" * 30)
            logger.info("ANALYZING SECTION")
            logger.info("-" * 30)
            logger.info(f"Headline: {headline}")
            logger.info(f"Section length: {len(section)} characters")

            # Download NLTK data if needed
            try:
                nltk.data.find('tokenizers/punkt')
            except LookupError:
                nltk.download('punkt')

            sentences = sent_tokenize(section)
            logger.info(f"Found {len(sentences)} sentences in section")
            if not sentences:
                logger.warning("No sentences found in section")
                return {
                    "accuracy_score": 50.0,
                    "flagged_phrases": [],
                    "detailed_scores": {
                        "nli": {"ENTAILMENT": 0.0, "CONTRADICTION": 0.0, "NEUTRAL": 1.0},
                        "sensationalism": {"factual reporting": 0.5, "accurate headline": 0.5}
                    }
                }

            # Categories for sensationalism check
            sensationalism_categories = [
                "clickbait",
                "sensationalized",
                "misleading",
                "factual reporting",
                "accurate headline"
            ]

            logger.info("Checking headline for sensationalism...")
            sensationalism_result = self.zero_shot(
                headline,
                sensationalism_categories,
                multi_label=True
            )
            sensationalism_scores = {
                label: score
                for label, score in zip(sensationalism_result['labels'], sensationalism_result['scores'])
            }
            logger.info(f"Sensationalism scores: {sensationalism_scores}")

            # Filter relevant sentences (longer than 20 chars)
            relevant_sentences = [s.strip() for s in sentences if len(s.strip()) > 20]
            logger.info(f"Found {len(relevant_sentences)} relevant sentences after filtering")
            if not relevant_sentences:
                logger.warning("No relevant sentences found in section")
                return {
                    "accuracy_score": 50.0,
                    "flagged_phrases": [],
                    "detailed_scores": {
                        "nli": {"ENTAILMENT": 0.0, "CONTRADICTION": 0.0, "NEUTRAL": 1.0},
                        "sensationalism": sensationalism_scores
                    }
                }

            # Process sentences in batches for contradiction/support
            nli_scores = []
            flagged_phrases = []
            batch_size = 8

            # Flag a sensationalized headline once: these scores describe the
            # headline itself, so repeating the flag for every sentence would
            # only produce duplicates
            if sensationalism_scores.get('sensationalized', 0) > 0.6 or sensationalism_scores.get('clickbait', 0) > 0.6:
                logger.info(f"Found sensationalized headline: {headline}")
                flagged_phrases.append({
                    'text': headline,
                    'type': 'Sensationalized',
                    'score': max(sensationalism_scores.get('sensationalized', 0), sensationalism_scores.get('clickbait', 0)),
                    'highlight': f"[SENSATIONALIZED] \"{headline}\""
                })

            logger.info("Processing sentences for contradictions...")
            for i in range(0, len(relevant_sentences), batch_size):
                batch = relevant_sentences[i:i + batch_size]
                batch_inputs = [f"{headline} [SEP] {sentence}" for sentence in batch]
                try:
                    # Get NLI scores for the batch
                    batch_results = self.nli_pipeline(batch_inputs, top_k=None)
                    if not isinstance(batch_results, list):
                        batch_results = [batch_results]
                    for sentence, result in zip(batch, batch_results):
                        scores = {item['label']: item['score'] for item in result}
                        nli_scores.append(scores)
                        # Flag contradictory content with a deliberately low threshold
                        if scores.get('CONTRADICTION', 0) > 0.3:
                            logger.info(f"Found contradictory sentence (score: {scores['CONTRADICTION']:.2f}): {sentence}")
                            flagged_phrases.append({
                                'text': sentence,
                                'type': 'Contradiction',
                                'score': scores['CONTRADICTION'],
                                'highlight': f"[CONTRADICTION] (Score: {round(scores['CONTRADICTION'] * 100, 1)}%) \"{sentence}\""
                            })
                except Exception as batch_error:
                    logger.warning(f"Batch processing error: {str(batch_error)}")
                    continue

            # Calculate aggregate scores with validation
            if not nli_scores:
                logger.warning("No NLI scores available")
                avg_scores = {"ENTAILMENT": 0.0, "CONTRADICTION": 0.0, "NEUTRAL": 1.0}
            else:
                try:
                    avg_scores = {
                        label: float(np.mean([
                            score.get(label, 0.0)
                            for score in nli_scores
                        ]))
                        for label in ['ENTAILMENT', 'CONTRADICTION', 'NEUTRAL']
                    }
                    logger.info(f"Average NLI scores: {avg_scores}")
                except Exception as agg_error:
                    logger.error(f"Error aggregating NLI scores: {str(agg_error)}")
                    avg_scores = {"ENTAILMENT": 0.0, "CONTRADICTION": 0.0, "NEUTRAL": 1.0}

            # Calculate headline accuracy score with validation
            try:
                accuracy_components = {
                    'entailment': avg_scores.get('ENTAILMENT', 0.0) * 0.4,
                    'non_contradiction': (1 - avg_scores.get('CONTRADICTION', 0.0)) * 0.3,
                    'non_sensational': (
                        sensationalism_scores.get('factual reporting', 0.0) +
                        sensationalism_scores.get('accurate headline', 0.0)
                    ) * 0.15,
                    'non_clickbait': (
                        1 - sensationalism_scores.get('clickbait', 0.0) -
                        sensationalism_scores.get('sensationalized', 0.0)
                    ) * 0.15
                }
                logger.info(f"Accuracy components: {accuracy_components}")
                accuracy_score = sum(accuracy_components.values()) * 100

                # Validate the final score; the non_clickbait component can go
                # negative under multi_label scoring, so clamp to [0, 100]
                if np.isnan(accuracy_score) or not np.isfinite(accuracy_score):
                    logger.warning("Invalid accuracy score calculated, using default")
                    accuracy_score = 50.0
                else:
                    accuracy_score = float(np.clip(accuracy_score, 0.0, 100.0))
                logger.info(f"Final accuracy score: {accuracy_score:.1f}")
            except Exception as score_error:
                logger.error(f"Error calculating accuracy score: {str(score_error)}")
                accuracy_score = 50.0

            # Sort flagged phrases by score and keep the top five unique ones
            sorted_phrases = sorted(
                flagged_phrases,
                key=lambda x: x['score'],
                reverse=True
            )
            unique_phrases = []
            seen = set()
            for phrase in sorted_phrases:
                if phrase['text'] not in seen:
                    unique_phrases.append(phrase)
                    seen.add(phrase['text'])
                if len(unique_phrases) >= 5:
                    break
            logger.info(f"Final number of flagged phrases: {len(unique_phrases)}")

            return {
                "accuracy_score": accuracy_score,
                "flagged_phrases": unique_phrases,
                "detailed_scores": {
                    "nli": avg_scores,
                    "sensationalism": sensationalism_scores
                }
            }
        except Exception as e:
            logger.error(f"Section analysis failed: {str(e)}")
            return {
                "accuracy_score": 50.0,
                "flagged_phrases": [],
                "detailed_scores": {
                    "nli": {"ENTAILMENT": 0.0, "CONTRADICTION": 0.0, "NEUTRAL": 1.0},
                    "sensationalism": {}
                }
            }

    def _analyze_traditional(self, headline: str, content: str) -> Dict[str, Any]:
        """Traditional headline analysis method."""
        try:
            # Download NLTK data if needed
            try:
                nltk.data.find('tokenizers/punkt')
            except LookupError:
                nltk.download('punkt')

            # Basic metrics
            headline_words = set(headline.lower().split())
            content_words = set(content.lower().split())

            # Calculate word overlap
            overlap_words = headline_words.intersection(content_words)
            overlap_score = len(overlap_words) / len(headline_words) if headline_words else 0

            # Check for clickbait patterns
            clickbait_patterns = [
                "you won't believe",
                "shocking",
                "mind blowing",
                "amazing",
                "incredible",
                "unbelievable",
                "must see",
                "click here",
                "find out",
                "what happens next"
            ]
            clickbait_count = sum(1 for pattern in clickbait_patterns if pattern in headline.lower())
            clickbait_penalty = clickbait_count * 10  # 10-point penalty per clickbait phrase

            # Calculate final score (0-100)
            base_score = overlap_score * 100
            final_score = max(0, min(100, base_score - clickbait_penalty))

            # Find sentences worth surfacing to the user
            flagged_phrases = []
            sentences = sent_tokenize(content)
            for sentence in sentences:
                # Flag sentences that share several words with the headline
                # (likely the passages the headline is based on)
                sentence_words = set(sentence.lower().split())
                if len(headline_words.intersection(sentence_words)) > 2:
                    flagged_phrases.append(sentence.strip())
                # Flag sentences with clickbait patterns
                if any(pattern in sentence.lower() for pattern in clickbait_patterns):
                    flagged_phrases.append(sentence.strip())

            return {
                "headline_vs_content_score": round(final_score, 1),
                "flagged_phrases": list(set(flagged_phrases))[:5]  # Limit to five unique phrases
            }
        except Exception as e:
            logger.error(f"Traditional analysis failed: {str(e)}")
            return {
                "headline_vs_content_score": 0,
                "flagged_phrases": []
            }

    def analyze(self, headline: str, content: str) -> Dict[str, Any]:
        """Analyze how well the headline matches the content."""
        try:
            logger.info("\n" + "=" * 50)
            logger.info("HEADLINE ANALYSIS STARTED")
            logger.info("=" * 50)

            if not headline.strip() or not content.strip():
                logger.warning("Empty headline or content provided")
                return {
                    "headline_vs_content_score": 0,
                    "flagged_phrases": []
                }

            # Use LLM analysis if available and enabled
            if self.use_ai and self.llm_available:
                logger.info("Using LLM analysis for headline")

                # Split content if needed
                sections = self._split_content(headline, content)

                # Analyze each section
                section_results = []
                for section in sections:
                    result = self._analyze_section(headline, section)
                    section_results.append(result)

                # Aggregate results across sections
                accuracy_scores = [r['accuracy_score'] for r in section_results]
                final_score = float(np.mean(accuracy_scores))

                # Combine and deduplicate flagged phrases
                all_phrases = []
                for result in section_results:
                    if 'flagged_phrases' in result:
                        all_phrases.extend(result['flagged_phrases'])

                # Sort by score and keep the top five unique phrases
                sorted_phrases = sorted(all_phrases, key=lambda x: x['score'], reverse=True)
                unique_phrases = []
                seen = set()
                for phrase in sorted_phrases:
                    if phrase['text'] not in seen:
                        unique_phrases.append(phrase)
                        seen.add(phrase['text'])
                    if len(unique_phrases) >= 5:
                        break

                return {
                    "headline_vs_content_score": round(final_score, 1),
                    "flagged_phrases": unique_phrases
                }
            else:
                # Use traditional analysis
                logger.info("Using traditional headline analysis")
                return self._analyze_traditional(headline, content)
        except Exception as e:
            logger.error(f"Headline analysis failed: {str(e)}")
            return {
                "headline_vs_content_score": 0,
                "flagged_phrases": []
            }