# Spaces: Sleeping (hosting-page status text captured with the file; not part of the program)
import hashlib
import os
import re
from urllib.parse import urlparse

import gradio as gr
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Multi-Model Configuration
# Maps an ensemble role name to the Hugging Face checkpoint id used for it.
# Note that "url_specialist" currently points at the same checkpoint as
# "primary".
MODELS = {
    "primary": "cybersectony/phishing-email-detection-distilbert_v2.4.1",
    "secondary": "microsoft/DialoGPT-medium",  # Fallback for context
    "url_specialist": "cybersectony/phishing-email-detection-distilbert_v2.4.1",  # URL-focused
}

# Global model storage: role name -> loaded model / tokenizer object.
# Both start empty and are filled in place when models are loaded.
models = {}
tokenizers = {}
class AdvancedPhishingDetector:
    """Phishing detector blending transformer predictions with regex heuristics.

    Loaded models and tokenizers live in the module-level ``models`` and
    ``tokenizers`` dicts (keyed by the role names in ``MODELS``), not on the
    instance. Probability vectors have four entries; this class treats
    indices 1 and 3 as the phishing classes (URL and email respectively) and
    indices 0 and 2 as the legitimate ones.
    """

    # Kept exactly as originally written. It only recognises links that carry
    # an explicit http:// or https:// scheme, so bare links such as
    # "bit.ly/xyz" are not counted.
    _URL_RE = re.compile(
        r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
    )
    _SUSPICIOUS_DOMAIN_MARKERS = ('bit.ly', 'tinyurl', 'shorturl', 'suspicious', 'phish', 'scam')
    _URGENCY_RE = re.compile(r'urgent|immediate|expire|suspend|verify|confirm|click|act now')
    _MONEY_RE = re.compile(r'\$|money|payment|refund|prize|winner|lottery')
    # "ssn" and "pin" must be whole words: as bare substrings, "pin" matched
    # inside every "-ping" word ("helping", "shipping", "shopping", ...).
    _PERSONAL_INFO_RE = re.compile(
        r'password|social security|credit card|account|\b(?:ssn|pin)s?\b'
    )
    _CAPS_RE = re.compile(r'[A-Z]{3,}')
    # Leading whitespace is tolerated so pasted text that starts with a blank
    # line or indentation is still recognised.
    _GREETING_RE = re.compile(r'\s*dear (?:customer|user|sir|madam)')
    _MISSPELLING_RE = re.compile(
        r'recieve|occured|seperate|definately|goverment|secruity|varify'
    )

    def __init__(self):
        """Create the detector and load the configured models immediately."""
        self.load_models()

    def load_models(self):
        """Load every configured model except the "secondary" role.

        Each role is loaded independently, so one failing checkpoint does not
        prevent the others from loading. Roles that share a checkpoint id
        share a single tokenizer/model pair instead of loading it twice.

        Returns:
            True if every attempted role loaded, False if any of them failed.
        """
        loaded_by_path = {}
        all_loaded = True
        for name, model_path in MODELS.items():
            if name == "secondary":
                continue  # Skip for now, use primary model
            try:
                if model_path not in loaded_by_path:
                    tokenizer = AutoTokenizer.from_pretrained(model_path)
                    model = AutoModelForSequenceClassification.from_pretrained(model_path)
                    model.eval()
                    loaded_by_path[model_path] = (tokenizer, model)
                # Register both together so a role never ends up with a
                # tokenizer but no model (or the reverse).
                tokenizers[name], models[name] = loaded_by_path[model_path]
            except Exception as e:
                print(f"Error loading models: {e}")
                all_loaded = False
        return all_loaded

    def extract_features(self, text):
        """Extract hand-crafted heuristic features from ``text``.

        Args:
            text: Raw message text (email body, URL, or similar).

        Returns:
            Dict with keys ``url_count``, ``has_suspicious_domains``,
            ``urgency_words``, ``money_mentions``, ``personal_info_requests``,
            ``spelling_errors``, ``excessive_caps``, ``generic_greetings``,
            ``email_length`` and ``has_attachments``.
        """
        lowered = text.lower()
        features = {}

        # URL features
        urls = self._URL_RE.findall(text)
        features['url_count'] = len(urls)
        features['has_suspicious_domains'] = any(
            marker in url.lower()
            for url in urls
            for marker in self._SUSPICIOUS_DOMAIN_MARKERS
        )

        # Text pattern features
        features['urgency_words'] = len(self._URGENCY_RE.findall(lowered))
        features['money_mentions'] = len(self._MONEY_RE.findall(lowered))
        features['personal_info_requests'] = len(self._PERSONAL_INFO_RE.findall(lowered))
        features['spelling_errors'] = self.count_potential_errors(text)
        features['excessive_caps'] = len(self._CAPS_RE.findall(text))

        # Sender authenticity indicators
        features['generic_greetings'] = 1 if self._GREETING_RE.match(lowered) else 0
        features['email_length'] = len(text)
        features['has_attachments'] = 1 if 'attachment' in lowered else 0
        return features

    def count_potential_errors(self, text):
        """Count occurrences of a fixed list of common misspellings in ``text``."""
        return len(self._MISSPELLING_RE.findall(text.lower()))

    def get_model_predictions(self, text):
        """Run each loaded model on ``text`` and return its class probabilities.

        Roles that are not loaded, or whose inference raises, are omitted from
        the result rather than replaced by a made-up distribution, so a caller
        that receives an empty dict can fall back to heuristics. Roles backed
        by the same tokenizer/model objects reuse one forward pass.

        Returns:
            Dict mapping role name to a list of softmax probabilities.
        """
        predictions = {}
        shared_results = {}
        for model_name in ('primary', 'url_specialist'):
            if model_name not in models:
                continue
            try:
                model = models[model_name]
                tokenizer = tokenizers[model_name]
                cache_key = (id(tokenizer), id(model))
                if cache_key not in shared_results:
                    inputs = tokenizer(
                        text,
                        return_tensors="pt",
                        truncation=True,
                        max_length=512,
                        padding=True,
                    )
                    with torch.no_grad():
                        outputs = model(**inputs)
                    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
                    shared_results[cache_key] = probs[0].tolist()
                predictions[model_name] = list(shared_results[cache_key])
            except Exception as e:
                print(f"Error with model {model_name}: {e}")
        return predictions

    def ensemble_predict(self, text):
        """Combine model predictions and heuristic features for ``text``.

        Returns:
            Tuple ``(probs, features, risk_score)`` where ``probs`` is a
            four-entry probability list, ``features`` is the dict from
            :meth:`extract_features` and ``risk_score`` is in ``[0, 1]``.
        """
        model_preds = self.get_model_predictions(text)
        features = self.extract_features(text)
        risk_score = self.calculate_risk_score(features)

        # No usable model output: rely on the heuristics alone.
        if not model_preds:
            return self.fallback_prediction(features)

        # Weighted average of the per-model distributions.
        weights = {'primary': 0.7, 'url_specialist': 0.3}
        ensemble_probs = [0.0, 0.0, 0.0, 0.0]
        total_weight = 0
        for model_name, probs in model_preds.items():
            weight = weights.get(model_name, 0.5)
            total_weight += weight
            for i, prob in enumerate(probs):
                ensemble_probs[i] += prob * weight

        # Renormalise so a missing model does not shrink the probabilities.
        if total_weight > 0:
            ensemble_probs = [p / total_weight for p in ensemble_probs]

        ensemble_probs = self.adjust_with_features(ensemble_probs, risk_score)
        return ensemble_probs, features, risk_score

    def calculate_risk_score(self, features):
        """Turn a feature dict into a heuristic risk score capped at 1.0.

        Most contributions are individually capped so that no single signal
        can dominate; the URL count is only bounded by the final cap.
        """
        score = 0
        # URL-based risk
        score += features['url_count'] * 0.1
        score += features['has_suspicious_domains'] * 0.3
        # Content-based risk
        score += min(features['urgency_words'] * 0.15, 0.4)
        score += min(features['money_mentions'] * 0.1, 0.3)
        score += min(features['personal_info_requests'] * 0.2, 0.5)
        score += min(features['spelling_errors'] * 0.1, 0.2)
        score += min(features['excessive_caps'] * 0.05, 0.15)
        # Generic patterns
        score += features['generic_greetings'] * 0.1
        return min(score, 1.0)  # Cap at 1.0

    def adjust_with_features(self, probs, risk_score):
        """Shift probability mass toward phishing when ``risk_score`` > 0.5.

        Args:
            probs: Four-entry probability sequence; it is not modified.
            risk_score: Heuristic risk in ``[0, 1]``.

        Returns:
            A new list, renormalised to sum to 1 whenever its total is positive.
        """
        adjusted = list(probs)
        if risk_score > 0.5:
            phishing_boost = risk_score * 0.3
            adjusted[1] += phishing_boost  # Phishing URL
            adjusted[3] += phishing_boost  # Phishing Email
            # Take half the boost away from each legitimate class, never
            # going below zero.
            adjusted[0] = max(0, adjusted[0] - phishing_boost / 2)
            adjusted[2] = max(0, adjusted[2] - phishing_boost / 2)

        total = sum(adjusted)
        if total > 0:
            adjusted = [p / total for p in adjusted]
        return adjusted

    def fallback_prediction(self, features):
        """Return a canned distribution chosen by heuristic risk alone.

        Used when no model prediction is available.

        Returns:
            Tuple ``(probs, features, risk_score)`` in the same shape as
            :meth:`ensemble_predict`.
        """
        risk_score = self.calculate_risk_score(features)
        if risk_score > 0.7:
            return [0.1, 0.4, 0.1, 0.4], features, risk_score  # High phishing
        if risk_score > 0.4:
            return [0.3, 0.2, 0.3, 0.2], features, risk_score  # Medium risk
        return [0.45, 0.05, 0.45, 0.05], features, risk_score  # Low risk
# Initialize detector
# Constructing the detector loads the configured models, so this runs once at
# import time and the instance is shared by every request.
detector = AdvancedPhishingDetector()
def advanced_predict_phishing(text):
    """Analyze ``text`` and build the report shown in the UI.

    Args:
        text: Message content to analyze. ``None``, empty, and
            whitespace-only input are all treated as "nothing entered".

    Returns:
        Tuple ``(report, confidences, risk_color)``:
        ``report`` is a Markdown string, ``confidences`` maps each class label
        to its probability (empty on the no-input and error paths), and
        ``risk_color`` is one of ``"red"``, ``"orange"``, ``"yellow"``,
        ``"green"``, or ``""`` when there was no input.
    """
    # Guard against None as well as blank text; calling .strip() on None
    # would raise before any message could be shown.
    if not text or not text.strip():
        return "Please enter some text to analyze", {}, ""

    try:
        # Get ensemble prediction
        probs, features, risk_score = detector.ensemble_predict(text)

        # Class labels in the index order of the probability vector.
        labels = {
            "Legitimate Email": probs[0],
            "Phishing URL": probs[1],
            "Legitimate URL": probs[2],
            "Phishing Email": probs[3],
        }

        # The highest-probability label is the primary classification.
        prediction, confidence = max(labels.items(), key=lambda item: item[1])

        # Risk banner: model verdict first, then the heuristic score.
        if "Phishing" in prediction and confidence > 0.8:
            risk_level = "🚨 HIGH RISK - Strong Phishing Indicators"
            risk_color = "red"
        elif "Phishing" in prediction or risk_score > 0.5:
            risk_level = "⚠️ MEDIUM RISK - Suspicious Patterns Detected"
            risk_color = "orange"
        elif risk_score > 0.3:
            risk_level = "⚡ LOW-MEDIUM RISK - Some Concerns"
            risk_color = "yellow"
        else:
            risk_level = "✅ LOW RISK - Appears Legitimate"
            risk_color = "green"

        # Feature analysis summary
        feature_alerts = []
        if features['has_suspicious_domains']:
            feature_alerts.append("Suspicious domain detected")
        if features['urgency_words'] > 2:
            feature_alerts.append("High urgency language")
        if features['personal_info_requests'] > 1:
            feature_alerts.append("Requests personal information")
        if features['spelling_errors'] > 0:
            feature_alerts.append("Potential spelling errors")

        if feature_alerts:
            alerts_block = "\n".join(f"• {alert}" for alert in feature_alerts)
        else:
            alerts_block = "• No significant risk patterns detected"

        # The report body is kept flush-left: Markdown would render indented
        # lines as a code block.
        result = f"""
### {risk_level}
**Primary Classification:** {prediction}
**Confidence:** {confidence:.1%}
**Feature Risk Score:** {risk_score:.2f}/1.00
**Analysis Alerts:**
{alerts_block}
**Technical Details:**
• URLs found: {features['url_count']}
• Urgency indicators: {features['urgency_words']}
• Personal info requests: {features['personal_info_requests']}
"""

        # Confidence breakdown for display (raw floats for gr.Label)
        confidence_data = dict(labels)
        return result, confidence_data, risk_color
    except Exception as e:
        # Surface the failure in the UI instead of crashing the request.
        return f"Error during analysis: {str(e)}", {}, "orange"
# Enhanced Gradio Interface
# NOTE(review): the layout nesting below was reconstructed from a copy of the
# file that had lost its indentation; confirm it against the deployed app.
# NOTE(review): emoji were restored from mis-decoded UTF-8. Where the original
# bytes were not recoverable (several of the plain "π" glyphs) a fitting emoji
# was chosen; confirm the exact glyphs if they matter.
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="EmailGuard - Advanced Phishing Detection",
    css="""
.risk-high { color: #dc2626 !important; font-weight: bold; }
.risk-low { color: #16a34a !important; font-weight: bold; }
.main-container { max-width: 900px; margin: 0 auto; }
.feature-box { background: #f8f9fa; padding: 15px; border-radius: 8px; margin: 10px 0; }
""",
) as demo:
    # Markdown bodies are kept flush-left so they are not rendered as code
    # blocks.
    gr.Markdown("""
# 🛡️ EmailGuard2 - Advanced AI Phishing Detection
**Multi-Model Ensemble System with Feature Analysis**
✨ **Enhanced Accuracy** • 🔍 **Deep Pattern Analysis** • 🚀 **Real-time Results**
""")

    with gr.Row():
        # Left column: input box and action buttons.
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="📧 Email Content, URL, or Suspicious Message",
                placeholder="Paste your email content, suspicious URL, or any text message here for comprehensive analysis...",
                lines=10,
                max_lines=20,
            )
            with gr.Row():
                analyze_btn = gr.Button(
                    "🔍 Advanced Analysis",
                    variant="primary",
                    size="lg",
                )
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")

        # Right column: report and per-class confidence.
        with gr.Column(scale=1):
            result_output = gr.Markdown(label="📊 Analysis Results")
            confidence_output = gr.Label(
                label="🎯 Confidence Breakdown",
                num_top_classes=4,
            )

    # Enhanced examples
    gr.Markdown("### 📝 Test These Examples:")
    examples = [
        ["URGENT: Your PayPal account has been limited! Verify immediately at http://paypal-security-check.suspicious.com/verify or lose access forever!"],
        ["Hi Mufasa, Thanks for sending the quarterly report. I've reviewed the numbers and they look good. Let's discuss in tomorrow's meeting. Best, Simba"],
        ["🎉 CONGRATULATIONS, Chinno! You've won $50,000! Click here to claim: bit.ly/winner123. Act fast, expires in 24hrs! Reply with SSN to confirm."],
        ["Your Microsoft Office subscription expires tomorrow. Renew now to avoid service interruption. Visit: https://office.microsoft.com/renew"],
        ["Dear Valued Customer, We detected unusual activity on your account. Please verify your identity by clicking the link below and entering your password."],
        ["Meeting reminder: Team standup at 10 AM in conference room A, Y4C Hub. Please bring your project updates. Thanks!"],
    ]
    gr.Examples(
        examples=examples,
        inputs=input_text,
        outputs=[result_output, confidence_output],
    )

    # The prediction function returns a third value (a risk colour) that no
    # visible component shows; one shared hidden State receives it so the
    # output count matches the function's return tuple.
    risk_color_state = gr.State()

    # Event handlers
    analyze_btn.click(
        fn=advanced_predict_phishing,
        inputs=input_text,
        outputs=[result_output, confidence_output, risk_color_state],
    )
    clear_btn.click(
        fn=lambda: ("", "", {}),
        outputs=[input_text, result_output, confidence_output],
    )
    input_text.submit(
        fn=advanced_predict_phishing,
        inputs=input_text,
        outputs=[result_output, confidence_output, risk_color_state],
    )

    gr.Markdown("""
---
### 🔬 Advanced Detection Features
**🤖 Multi-Model Ensemble:** Combines predictions from specialized models
**🎯 Feature Engineering:** Hand-crafted rules for pattern detection
**⚖️ Bias Reduction:** Multiple validation layers prevent false positives
**📊 Risk Scoring:** Comprehensive analysis beyond simple classification
**🔗 URL Analysis:** Specialized detection for malicious links
**📝 Content Analysis:** Deep text pattern recognition
### ⚡ What Makes This More Accurate:
- **Ensemble Learning:** Multiple models vote on final decision
- **Feature Fusion:** AI + Rule-based detection combined
- **Adaptive Thresholds:** Dynamic risk assessment
- **Comprehensive Coverage:** Email, URL, and text message analysis
**⚠️ Academic Research Tool:** For educational purposes - always verify through official channels.
""")
# Start the web server only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch(
        share=False,
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        show_error=True,
    )