"""Gradio demo for DGA domain classifier on HuggingFace Spaces.

This demo loads the model from HuggingFace Hub and provides an interactive
interface for classifying domains as legitimate or DGA-generated.
"""

import html

import torch
import gradio as gr

# Import custom model and encoding
from model import DGAEncoderForSequenceClassification
from charset import encode_domain

# Load model from HuggingFace Hub
MODEL_NAME = "ccss17/dga-transformer-encoder"
# Sequence length passed to encode_domain for every request.
MAX_LEN = 64

print(f"Loading model from {MODEL_NAME}...")
model = DGAEncoderForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"Model loaded on {device}")

# Display labels indexed by predicted class id (0 = legitimate, 1 = DGA).
LABEL_NAMES = ["✅ Legitimate", "🚨 DGA (Malicious)"]


def _class_probabilities(domain: str) -> torch.Tensor:
    """Return the model's softmax class probabilities for a single domain.

    Args:
        domain: Already-normalized (stripped, lower-cased) domain name.

    Returns:
        1-D tensor of class probabilities; index 0 is "legitimate" and
        index 1 is "DGA", matching LABEL_NAMES.
    """
    input_ids = torch.tensor(
        [encode_domain(domain, max_len=MAX_LEN)], device=device
    )
    with torch.no_grad():
        logits = model(input_ids=input_ids).logits
    # Drop the size-1 batch dimension added when building input_ids.
    return torch.softmax(logits, dim=-1)[0]


def _markdown_code_cell(text: str) -> str:
    """Format user text as an inline code span safe inside a Markdown table row.

    Pipes are escaped so they cannot split the row into extra columns, and a
    double-backtick fence lets a stray single backtick render literally
    instead of closing the code span early.
    """
    return "`` " + text.replace("|", "\\|") + " ``"


def predict_domain(domain: str):
    """Classify a domain as legitimate or DGA-generated.

    Args:
        domain: Domain name to classify (e.g., "google.com", "xjkd8f2h.com").

    Returns:
        tuple: (prediction_label, confidence_pct, html_output). For blank
        input the label is a warning, the confidence is empty, and the HTML
        slot holds a hint message.
    """
    if not domain or not domain.strip():
        return "⚠️ Invalid Input", "", "Please enter a domain name."

    domain = domain.strip().lower()

    probs = _class_probabilities(domain)
    pred_class = int(torch.argmax(probs).item())
    prediction = LABEL_NAMES[pred_class]
    confidence_pct = f"{probs[pred_class].item() * 100:.2f}%"
    legit_prob = probs[0].item()
    dga_prob = probs[1].item()

    # The domain is raw user input rendered by gr.HTML, so it must be escaped
    # to keep it from injecting markup or script into the page.
    html_output = f"""
<div>
  <h3>{prediction}</h3>
  <p><strong>Domain:</strong> <code>{html.escape(domain)}</code></p>
  <p><strong>Confidence:</strong> {confidence_pct}</p>
  <hr>
  <p><strong>Probability Breakdown:</strong></p>
  <ul>
    <li>✅ Legitimate: {legit_prob * 100:.2f}%</li>
    <li>🚨 DGA (Malicious): {dga_prob * 100:.2f}%</li>
  </ul>
</div>
"""
    return prediction, confidence_pct, html_output


def predict_batch(domains_text: str):
    """Classify multiple domains at once.

    Args:
        domains_text: Newline-separated list of domains; blank lines are
            ignored and each entry is stripped and lower-cased.

    Returns:
        str: Formatted results table in Markdown, or a hint message when no
        domain was entered.
    """
    domains = [
        line.strip().lower()
        for line in (domains_text or "").splitlines()
        if line.strip()
    ]
    if not domains:
        return "Please enter one or more domain names (one per line)."

    rows = [
        "| Domain | Prediction | Confidence |",
        "|--------|------------|------------|",
    ]
    for domain in domains:
        probs = _class_probabilities(domain)
        pred_class = int(torch.argmax(probs).item())
        label = "Legitimate ✅" if pred_class == 0 else "DGA 🚨"
        confidence_pct = f"{probs[pred_class].item() * 100:.1f}%"
        rows.append(
            f"| {_markdown_code_cell(domain)} | {label} | {confidence_pct} |"
        )
    return "\n".join(rows)


# Example domains for quick testing
EXAMPLES = [
    ["google.com"],
    ["github.com"],
    ["stackoverflow.com"],
    ["xjkd8f2h.com"],
    ["qwfp93nx.net"],
    ["h4fk29fd.org"],
    ["facebook.com"],
    ["fjdkslajf.com"],
]

BATCH_EXAMPLES = [
    "google.com\ngithub.com\nstackoverflow.com",
    "xjkd8f2h.com\nqwfp93nx.net\nh4fk29fd.org",
    "amazon.com\nfacebook.com\ntwitter.com\nmicrosoft.com",
]

# Markdown blocks are kept at column 0 so indentation inside the UI layout
# cannot turn their lines into Markdown code blocks.
HEADER_MD = """\
# 🔍 DGA Domain Classifier

**Detect malicious domains generated by Domain Generation Algorithms (DGAs)**

This model uses a transformer-based neural network to classify domains as either:
- ✅ **Legitimate**: Normal, human-registered domains
- 🚨 **DGA (Malicious)**: Algorithmically-generated domains used by malware for C2 communication

---

### Model Details
- **Architecture**: Custom Transformer Encoder (4 layers, 256 dim)
- **Parameters**: 3.2M
- **F1 score**: 96.78% on test set
- **Inference Speed**: <1ms per domain (GPU)

---
"""

ABOUT_MD = """\
## 📚 What are DGAs?

**Domain Generation Algorithms (DGAs)** are techniques used by malware to generate large numbers of pseudo-random domain names for C2 (command-and-control) communication.

### Why DGAs are dangerous:
- **Evasion**: Traditional blacklists can't keep up with thousands of generated domains
- **Resilience**: Even if some domains are blocked, malware can try others
- **Stealth**: DGA domains look random, making detection challenging

### How this model works:
1. **Character-level tokenization**: Breaks domain into individual characters
2. **Transformer encoder**: Learns patterns in character sequences
3. **Self-attention**: Detects unusual character combinations (e.g., `xqz`, `fgh`)
4. **Classification**: Predicts if domain is legitimate or DGA-generated

### Key Features:
- **High accuracy**: 96.78% F1 score on test set
- **Fast inference**: <1ms per domain (GPU) or ~10ms (CPU)
- **Lightweight**: Only 3.2M parameters
- **Production-ready**: Trained on real-world malware domains

### Examples:
**Legitimate domains** (structured, pronounceable):
- `google.com`, `github.com`, `stackoverflow.com`
- `api-docs.company.com`, `cdn-assets.example.org`

**DGA domains** (random, unpronounceable):
- `xjkd8f2h.com`, `qwfp93nx.net`, `h4fk29fd.org`
- `kdjf92jd.info`, `zmxbv73k.biz`

---

### Technical Details:
- **Model**: Custom Transformer Encoder
- **Training data**: ExtraHop DGA dataset
- **Framework**: PyTorch + HuggingFace Transformers
- **Experiment tracking**: Weights & Biases

### Links:
- [Model Card](https://huggingface.co/ccss17/dga-transformer-encoder)
- [GitHub Repository](https://github.com/ccss17/DGA-Transformer-Encoder)
- [ExtraHop DGA Dataset](https://github.com/extrahop/dga-training-data)

---

**Built with ❤️ using PyTorch, HuggingFace, and Gradio**
"""

FOOTER_MD = """\
---

⚠️ Disclaimer: This model is for educational and research purposes. Always use multiple detection methods in production security systems.

Model F1 score: 96.78% | False positive rate: ~3%
"""

# Create Gradio interface
with gr.Blocks(title="DGA Domain Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown(HEADER_MD)

    with gr.Tabs():
        # Tab 1: Single domain prediction
        with gr.Tab("Single Domain"):
            gr.Markdown("### Test a single domain name")

            with gr.Row():
                with gr.Column(scale=2):
                    domain_input = gr.Textbox(
                        label="Enter Domain Name",
                        placeholder="e.g., google.com, xjkd8f2h.com",
                        lines=1,
                    )
                    predict_btn = gr.Button(
                        "🔍 Classify Domain", variant="primary", size="lg"
                    )

                with gr.Column(scale=1):
                    prediction_output = gr.Textbox(label="Prediction", lines=1)
                    confidence_output = gr.Textbox(label="Confidence", lines=1)

            html_output = gr.HTML(label="Detailed Results")

            predict_btn.click(
                fn=predict_domain,
                inputs=domain_input,
                outputs=[prediction_output, confidence_output, html_output],
            )

            gr.Markdown("### Try these examples:")
            gr.Examples(
                examples=EXAMPLES,
                inputs=domain_input,
                outputs=[prediction_output, confidence_output, html_output],
                fn=predict_domain,
                cache_examples=False,
            )

        # Tab 2: Batch prediction
        with gr.Tab("Batch Prediction"):
            gr.Markdown("### Classify multiple domains at once (one per line)")

            batch_input = gr.Textbox(
                label="Enter Domains (one per line)",
                placeholder="google.com\ngithub.com\nxjkd8f2h.com",
                lines=8,
            )
            batch_btn = gr.Button(
                "🔍 Classify All Domains", variant="primary", size="lg"
            )
            batch_output = gr.Markdown(label="Results")

            batch_btn.click(
                fn=predict_batch,
                inputs=batch_input,
                outputs=batch_output,
            )

            gr.Markdown("### Try these examples:")
            gr.Examples(
                examples=BATCH_EXAMPLES,
                inputs=batch_input,
                outputs=batch_output,
                fn=predict_batch,
                cache_examples=False,
            )

        # Tab 3: About
        with gr.Tab("About"):
            gr.Markdown(ABOUT_MD)

    gr.Markdown(FOOTER_MD)

if __name__ == "__main__":
    demo.launch()