"""Gradio demo for DGA domain classifier on HuggingFace Spaces.
This demo loads the model from HuggingFace Hub and provides an interactive
interface for classifying domains as legitimate or DGA-generated.
"""
import html

import torch
import gradio as gr

# Import custom model and encoding
from model import DGAEncoderForSequenceClassification
from charset import encode_domain
# Fetch the pretrained classifier from the HuggingFace Hub, switch it to
# evaluation mode, then move it to the GPU when one is available (CPU otherwise).
MODEL_NAME = "ccss17/dga-transformer-encoder"
print(f"Loading model from {MODEL_NAME}...")
model = DGAEncoderForSequenceClassification.from_pretrained(MODEL_NAME).eval()
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)
print(f"Model loaded on {device}")
def predict_domain(domain: str):
    """Classify a single domain as legitimate or DGA-generated.

    The input is stripped and lower-cased, encoded with ``encode_domain``
    (``max_len=64``) and passed through the module-level ``model`` on
    ``device`` with gradients disabled.

    Args:
        domain: Domain name to classify (e.g., "google.com", "xjkd8f2h.com").

    Returns:
        tuple: ``(prediction_label, confidence_score, html_output)``:
        the label of the most probable class, that class's probability as a
        percentage string with two decimals, and a per-class breakdown for
        the HTML results panel. Empty or whitespace-only input returns a
        fixed "Invalid Input" tuple without running the model.
    """
    if not domain or not domain.strip():
        return "⚠️ Invalid Input", "", "Please enter a domain name."

    domain = domain.strip().lower()

    # Encode domain to token IDs (wrapped in a list to form a batch of one)
    input_ids = torch.tensor(
        [encode_domain(domain, max_len=64)], device=device
    )

    # Get model prediction
    with torch.no_grad():
        logits = model(input_ids=input_ids).logits
        probs = torch.softmax(logits, dim=-1)
        pred_class = torch.argmax(probs, dim=-1).item()
        confidence = probs[0, pred_class].item()

    # Format output; index 0 is the legitimate class, index 1 the DGA class
    label_names = ["✅ Legitimate", "🚨 DGA (Malicious)"]
    prediction = label_names[pred_class]
    confidence_pct = f"{confidence * 100:.2f}%"
    legit_prob = probs[0, 0].item()
    dga_prob = probs[0, 1].item()

    # The domain is raw user input and this string is rendered as HTML, so it
    # must be escaped; otherwise input such as "<img onerror=...>" would be
    # injected into the page as markup.
    safe_domain = html.escape(domain)

    # Create detailed HTML output
    html_output = (
        f"\n{prediction}\n"
        f"Domain: {safe_domain}\n"
        f"Confidence: {confidence_pct}\n"
        "Probability Breakdown:\n"
        "✅ Legitimate:\n"
        f"{legit_prob * 100:.2f}%\n"
        "🚨 DGA (Malicious):\n"
        f"{dga_prob * 100:.2f}%\n"
    )
    return prediction, confidence_pct, html_output
def _markdown_code_cell(text: str) -> str:
    """Return ``text`` as a Markdown inline-code span safe inside a table cell."""
    # An unescaped "|" would end the table cell early.
    text = text.replace("|", "\\|")
    # Use the shortest backtick fence that does not occur in the text, so
    # backticks in the input cannot close the code span and leak markup.
    fence = "`"
    while fence in text:
        fence += "`"
    # A code span whose content starts or ends with a backtick needs one space
    # of padding so the content is not read as part of the fence.
    pad = " " if text.startswith("`") or text.endswith("`") else ""
    return f"{fence}{pad}{text}{pad}{fence}"


def predict_batch(domains_text: str) -> str:
    """Classify multiple domains at once.

    Each non-blank line is stripped, lower-cased, encoded with
    ``encode_domain`` (``max_len=64``) and classified individually with the
    module-level ``model`` on ``device``.

    Args:
        domains_text: Newline-separated list of domains.

    Returns:
        str: A Markdown table with one row per domain (domain, predicted
        label, confidence as a percentage with one decimal), or a prompt
        message when the input is empty or whitespace-only.
    """
    if not domains_text or not domains_text.strip():
        return "Please enter one or more domain names (one per line)."

    # splitlines() also handles "\r\n" and bare "\r" line endings. The guard
    # above guarantees at least one non-blank line, so this list is non-empty.
    domains = [
        line.strip().lower()
        for line in domains_text.splitlines()
        if line.strip()
    ]

    rows = [
        "| Domain | Prediction | Confidence |",
        "|--------|------------|------------|",
    ]

    # Gradients are never needed here, so disable them once for the whole loop.
    with torch.no_grad():
        for domain in domains:
            # Encode domain (wrapped in a list to form a batch of one)
            input_ids = torch.tensor(
                [encode_domain(domain, max_len=64)], device=device
            )

            # Predict
            logits = model(input_ids=input_ids).logits
            probs = torch.softmax(logits, dim=-1)
            pred_class = torch.argmax(probs, dim=-1).item()
            confidence = probs[0, pred_class].item()

            label = "Legitimate ✅" if pred_class == 0 else "DGA 🚨"
            confidence_pct = f"{confidence * 100:.1f}%"
            # The domain is raw user input; render it through the helper so
            # "|" or backticks cannot break the table or escape the code span.
            rows.append(
                f"| {_markdown_code_cell(domain)} | {label} | {confidence_pct} |"
            )

    return "\n".join(rows)
# Sample inputs offered in the UI for quick testing.
# Single-domain examples: one row per example, each row holding one value.
EXAMPLES = [
    [name]
    for name in (
        "google.com",
        "github.com",
        "stackoverflow.com",
        "xjkd8f2h.com",
        "qwfp93nx.net",
        "h4fk29fd.org",
        "facebook.com",
        "fjdkslajf.com",
    )
]
# Batch examples: each entry is one newline-separated block of domains.
BATCH_EXAMPLES = [
    "\n".join(group)
    for group in (
        ("google.com", "github.com", "stackoverflow.com"),
        ("xjkd8f2h.com", "qwfp93nx.net", "h4fk29fd.org"),
        ("amazon.com", "facebook.com", "twitter.com", "microsoft.com"),
    )
]
# Create Gradio interface
# The app object is bound to `demo`, which is launched at the bottom of the
# file. Layout: an intro header, three tabs (single domain, batch, about),
# and a disclaimer footer.
with gr.Blocks(title="DGA Domain Classifier", theme=gr.themes.Soft()) as demo:
    # Page header: title, what the two classes mean, and a model summary
    gr.Markdown("""
# 🔍 DGA Domain Classifier
**Detect malicious domains generated by Domain Generation Algorithms (DGAs)**
This model uses a transformer-based neural network to classify domains as either:
- ✅ **Legitimate**: Normal, human-registered domains
- 🚨 **DGA (Malicious)**: Algorithmically-generated domains used by malware for C2 communication
---
### Model Details
- **Architecture**: Custom Transformer Encoder (4 layers, 256 dim)
- **Parameters**: 3.2M
- **Accuracy**: 96.78% F1 score on test set
- **Inference Speed**: <1ms per domain
---
""")
    with gr.Tabs():
        # Tab 1: Single domain prediction
        with gr.Tab("Single Domain"):
            gr.Markdown("### Test a single domain name")
            with gr.Row():
                # Wider column: the domain text box and the classify button
                with gr.Column(scale=2):
                    domain_input = gr.Textbox(
                        label="Enter Domain Name",
                        placeholder="e.g., google.com, xjkd8f2h.com",
                        lines=1,
                    )
                    predict_btn = gr.Button(
                        "🔍 Classify Domain", variant="primary", size="lg"
                    )
                # Narrower column: predicted label and its confidence
                with gr.Column(scale=1):
                    prediction_output = gr.Textbox(label="Prediction", lines=1)
                    confidence_output = gr.Textbox(label="Confidence", lines=1)
            # Receives the third value returned by predict_domain
            html_output = gr.HTML(label="Detailed Results")
            # The three outputs line up with predict_domain's 3-tuple:
            # (prediction label, confidence string, detailed output)
            predict_btn.click(
                fn=predict_domain,
                inputs=domain_input,
                outputs=[prediction_output, confidence_output, html_output],
            )
            gr.Markdown("### Try these examples:")
            # Example domains from EXAMPLES; cache_examples=False, so no
            # predictions are pre-computed for them
            gr.Examples(
                examples=EXAMPLES,
                inputs=domain_input,
                outputs=[prediction_output, confidence_output, html_output],
                fn=predict_domain,
                cache_examples=False,
            )
        # Tab 2: Batch prediction
        with gr.Tab("Batch Prediction"):
            gr.Markdown("### Classify multiple domains at once (one per line)")
            batch_input = gr.Textbox(
                label="Enter Domains (one per line)",
                placeholder="google.com\ngithub.com\nxjkd8f2h.com",
                lines=8,
            )
            batch_btn = gr.Button(
                "🔍 Classify All Domains", variant="primary", size="lg"
            )
            # Markdown component because predict_batch returns a Markdown table
            batch_output = gr.Markdown(label="Results")
            batch_btn.click(
                fn=predict_batch,
                inputs=batch_input,
                outputs=batch_output,
            )
            gr.Markdown("### Try these examples:")
            # Each entry in BATCH_EXAMPLES is one newline-separated domain list
            gr.Examples(
                examples=BATCH_EXAMPLES,
                inputs=batch_input,
                outputs=batch_output,
                fn=predict_batch,
                cache_examples=False,
            )
        # Tab 3: About
        # Static background text only; no inputs or event handlers
        with gr.Tab("About"):
            gr.Markdown("""
## 📚 What are DGAs?
**Domain Generation Algorithms (DGAs)** are techniques used by malware to generate
large numbers of pseudo-random domain names for C2 (command-and-control) communication.
### Why DGAs are dangerous:
- **Evasion**: Traditional blacklists can't keep up with thousands of generated domains
- **Resilience**: Even if some domains are blocked, malware can try others
- **Stealth**: DGA domains look random, making detection challenging
### How this model works:
1. **Character-level tokenization**: Breaks domain into individual characters
2. **Transformer encoder**: Learns patterns in character sequences
3. **Self-attention**: Detects unusual character combinations (e.g., `xqz`, `fgh`)
4. **Classification**: Predicts if domain is legitimate or DGA-generated
### Key Features:
- **High accuracy**: 96.78% F1 score on test set
- **Fast inference**: <1ms per domain (GPU) or ~10ms (CPU)
- **Lightweight**: Only 3.2M parameters
- **Production-ready**: Trained on real-world malware domains
### Examples:
**Legitimate domains** (structured, pronounceable):
- `google.com`, `github.com`, `stackoverflow.com`
- `api-docs.company.com`, `cdn-assets.example.org`
**DGA domains** (random, unpronounceable):
- `xjkd8f2h.com`, `qwfp93nx.net`, `h4fk29fd.org`
- `kdjf92jd.info`, `zmxbv73k.biz`
---
### Technical Details:
- **Model**: Custom Transformer Encoder
- **Training data**: ExtraHop DGA dataset
- **Framework**: PyTorch + HuggingFace Transformers
- **Experiment tracking**: Weights & Biases
### Links:
- [Model Card](https://huggingface.co/ccss17/dga-transformer-encoder)
- [GitHub Repository](https://github.com/ccss17/DGA-Transformer-Encoder)
- [ExtraHop DGA Dataset](https://github.com/extrahop/dga-training-data)
---
**Built with ❤️ using PyTorch, HuggingFace, and Gradio**
""")
    # Footer disclaimer, placed after the tab container
    gr.Markdown("""
---
⚠️ Disclaimer: This model is for educational and research purposes.
Always use multiple detection methods in production security systems.
Model accuracy: 96.78% | False positive rate: ~3%
""")
# Launch the Gradio app only when this file is executed as a script, so that
# importing the module does not call demo.launch().
if __name__ == "__main__":
    demo.launch()