Spaces:
Sleeping
Sleeping
File size: 11,247 Bytes
c11fb0c 5d1d43b c11fb0c 5d1d43b c11fb0c 5d1d43b c11fb0c 875a2e3 c11fb0c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 |
"""Gradio demo for DGA domain classifier on HuggingFace Spaces.
This demo loads the model from HuggingFace Hub and provides an interactive
interface for classifying domains as legitimate or DGA-generated.
"""
import html

import gradio as gr
import torch

# Import custom model and encoding
from model import DGAEncoderForSequenceClassification
from charset import encode_domain
# Load the classifier from the HuggingFace Hub once at import time and keep
# it in inference mode on the best available device (GPU when present).
MODEL_NAME = "ccss17/dga-transformer-encoder"

print(f"Loading model from {MODEL_NAME}...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = (
    DGAEncoderForSequenceClassification
    .from_pretrained(MODEL_NAME)
    .eval()  # disable training-only behaviour such as dropout
    .to(device)
)
print(f"Model loaded on {device}")
def predict_domain(domain: str):
    """Classify a single domain as legitimate or DGA-generated.

    Args:
        domain: Domain name to classify (e.g. "google.com", "xjkd8f2h.com").
            Surrounding whitespace is stripped and the name is lower-cased
            before it is encoded.

    Returns:
        tuple: ``(prediction_label, confidence_pct, html_output)``.
        ``prediction_label`` is the human-readable class name,
        ``confidence_pct`` is the winning class probability formatted as a
        percentage string, and ``html_output`` is an HTML card with the
        per-class probability breakdown. For empty / whitespace-only / None
        input, a warning label, an empty confidence string and a short hint
        are returned instead and the model is not called.
    """
    if not domain or not domain.strip():
        return "\u26a0\ufe0f Invalid Input", "", "Please enter a domain name."

    domain = domain.strip().lower()

    # Encode to token IDs and wrap in a list to add a batch dimension of 1.
    input_ids = torch.tensor([encode_domain(domain, max_len=64)], device=device)

    with torch.no_grad():
        logits = model(input_ids=input_ids).logits
    probs = torch.softmax(logits, dim=-1)
    pred_class = torch.argmax(probs, dim=-1).item()
    confidence = probs[0, pred_class].item()

    # Index 0 = legitimate, index 1 = DGA (matches the probs[0, 0] /
    # probs[0, 1] breakdown below).
    label_names = ["\u2705 Legitimate", "\U0001f6a8 DGA (Malicious)"]
    prediction = label_names[pred_class]
    confidence_pct = f"{confidence * 100:.2f}%"

    legit_pct = probs[0, 0].item() * 100
    dga_pct = probs[0, 1].item() * 100

    is_legit = pred_class == 0
    card_bg = "#d4edda" if is_legit else "#f8d7da"
    title_color = "#155724" if is_legit else "#721c24"

    # `domain` is raw user input and gr.HTML renders markup, so escape it to
    # show e.g. "<img onerror=...>" as text instead of injecting it.
    safe_domain = html.escape(domain)

    html_output = f"""
<div style="padding: 20px; border-radius: 10px; background: {card_bg};">
    <h2 style="margin-top: 0; color: {title_color};">
        {prediction}
    </h2>
    <p style="font-size: 18px; margin: 10px 0;">
        <strong>Domain:</strong> <code>{safe_domain}</code>
    </p>
    <p style="font-size: 18px; margin: 10px 0;">
        <strong>Confidence:</strong> {confidence_pct}
    </p>
    <hr style="margin: 15px 0; border: none; border-top: 1px solid #ccc;">
    <h3>Probability Breakdown:</h3>
    <div style="margin: 10px 0;">
        <div style="display: flex; justify-content: space-between; margin-bottom: 5px;">
            <span>\u2705 Legitimate:</span>
            <strong>{legit_pct:.2f}%</strong>
        </div>
        <div style="background: #e9ecef; border-radius: 5px; overflow: hidden; height: 20px;">
            <div style="background: #28a745; height: 100%; width: {legit_pct:.2f}%;"></div>
        </div>
    </div>
    <div style="margin: 10px 0;">
        <div style="display: flex; justify-content: space-between; margin-bottom: 5px;">
            <span>\U0001f6a8 DGA (Malicious):</span>
            <strong>{dga_pct:.2f}%</strong>
        </div>
        <div style="background: #e9ecef; border-radius: 5px; overflow: hidden; height: 20px;">
            <div style="background: #dc3545; height: 100%; width: {dga_pct:.2f}%;"></div>
        </div>
    </div>
</div>
"""
    return prediction, confidence_pct, html_output
def predict_batch(domains_text: str):
    """Classify several domains, one per line, and report them as a table.

    Args:
        domains_text: Newline-separated domains. Blank lines are skipped;
            each remaining line is stripped and lower-cased before encoding.

    Returns:
        str: A Markdown table with one ``| Domain | Prediction | Confidence |``
        row per domain, or a short hint message when no domain was given.
    """
    if not domains_text or not domains_text.strip():
        return "Please enter one or more domain names (one per line)."

    domains = [
        line.strip().lower()
        for line in domains_text.splitlines()
        if line.strip()
    ]
    if not domains:
        return "No valid domains provided."

    rows = [
        "| Domain | Prediction | Confidence |",
        "|--------|------------|------------|",
    ]
    # One no_grad context for the whole loop instead of re-entering per item.
    with torch.no_grad():
        for domain in domains:
            input_ids = torch.tensor(
                [encode_domain(domain, max_len=64)], device=device
            )
            probs = torch.softmax(model(input_ids=input_ids).logits, dim=-1)
            pred_class = torch.argmax(probs, dim=-1).item()
            confidence = probs[0, pred_class].item()

            label = "Legitimate \u2705" if pred_class == 0 else "DGA \U0001f6a8"
            # A raw "|" in user input would split the table cell and shift
            # the remaining columns; GFM requires it escaped even in code.
            cell = domain.replace("|", "\\|")
            rows.append(f"| `{cell}` | {label} | {confidence * 100:.1f}% |")

    return "\n".join(rows)
# Example inputs for the "Single Domain" tab. Each example is a one-element
# row because that tab feeds exactly one input component.
EXAMPLES = [
    [example_domain]
    for example_domain in (
        "google.com",
        "github.com",
        "stackoverflow.com",
        "xjkd8f2h.com",
        "qwfp93nx.net",
        "h4fk29fd.org",
        "facebook.com",
        "fjdkslajf.com",
    )
]

# Example inputs for the "Batch Prediction" tab: each entry is a single
# newline-separated block of domains, as typed into the batch textbox.
BATCH_EXAMPLES = [
    "\n".join(group)
    for group in (
        ("google.com", "github.com", "stackoverflow.com"),
        ("xjkd8f2h.com", "qwfp93nx.net", "h4fk29fd.org"),
        ("amazon.com", "facebook.com", "twitter.com", "microsoft.com"),
    )
]
# Markdown texts for the UI. They are kept flush-left at module level so
# Markdown never sees leading indentation (4+ spaces would render as a code
# block), and blank lines separate paragraphs, lists and `---` rules.
_HEADER_MD = """
# \U0001f50d DGA Domain Classifier

**Detect malicious domains generated by Domain Generation Algorithms (DGAs)**

This model uses a transformer-based neural network to classify domains as either:

- \u2705 **Legitimate**: Normal, human-registered domains
- \U0001f6a8 **DGA (Malicious)**: Algorithmically-generated domains used by malware for C2 communication

---

### Model Details

- **Architecture**: Custom Transformer Encoder (4 layers, 256 dim)
- **Parameters**: 3.2M
- **Accuracy**: 96.78% F1 score on test set
- **Inference Speed**: <1ms per domain

---
"""

_ABOUT_MD = """
## \U0001f4da What are DGAs?

**Domain Generation Algorithms (DGAs)** are techniques used by malware to generate
large numbers of pseudo-random domain names for C2 (command-and-control) communication.

### Why DGAs are dangerous:

- **Evasion**: Traditional blacklists can't keep up with thousands of generated domains
- **Resilience**: Even if some domains are blocked, malware can try others
- **Stealth**: DGA domains look random, making detection challenging

### How this model works:

1. **Character-level tokenization**: Breaks domain into individual characters
2. **Transformer encoder**: Learns patterns in character sequences
3. **Self-attention**: Detects unusual character combinations (e.g., `xqz`, `fgh`)
4. **Classification**: Predicts if domain is legitimate or DGA-generated

### Key Features:

- **High accuracy**: 96.78% F1 score on test set
- **Fast inference**: <1ms per domain (GPU) or ~10ms (CPU)
- **Lightweight**: Only 3.2M parameters
- **Production-ready**: Trained on real-world malware domains

### Examples:

**Legitimate domains** (structured, pronounceable):

- `google.com`, `github.com`, `stackoverflow.com`
- `api-docs.company.com`, `cdn-assets.example.org`

**DGA domains** (random, unpronounceable):

- `xjkd8f2h.com`, `qwfp93nx.net`, `h4fk29fd.org`
- `kdjf92jd.info`, `zmxbv73k.biz`

---

### Technical Details:

- **Model**: Custom Transformer Encoder
- **Training data**: ExtraHop DGA dataset
- **Framework**: PyTorch + HuggingFace Transformers
- **Experiment tracking**: Weights & Biases

### Links:

- [Model Card](https://huggingface.co/ccss17/dga-transformer-encoder)
- [GitHub Repository](https://github.com/ccss17/DGA-Transformer-Encoder)
- [ExtraHop DGA Dataset](https://github.com/extrahop/dga-training-data)

---

**Built with \u2764\ufe0f using PyTorch, HuggingFace, and Gradio**
"""

_FOOTER_MD = """
---

<div style="text-align: center; color: #666; font-size: 14px;">
<p>\u26a0\ufe0f <strong>Disclaimer</strong>: This model is for educational and research purposes.
Always use multiple detection methods in production security systems.</p>
<p>Model accuracy: 96.78% | False positive rate: ~3%</p>
</div>
"""

# Create Gradio interface: a header, three tabs (single / batch / about) and
# a footer. `demo` is the app object launched by the main guard.
with gr.Blocks(title="DGA Domain Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Tabs():
        # Tab 1: single domain prediction
        with gr.Tab("Single Domain"):
            gr.Markdown("### Test a single domain name")
            with gr.Row():
                with gr.Column(scale=2):
                    domain_input = gr.Textbox(
                        label="Enter Domain Name",
                        placeholder="e.g., google.com, xjkd8f2h.com",
                        lines=1,
                    )
                    predict_btn = gr.Button(
                        "\U0001f50d Classify Domain", variant="primary", size="lg"
                    )
                with gr.Column(scale=1):
                    prediction_output = gr.Textbox(label="Prediction", lines=1)
                    confidence_output = gr.Textbox(label="Confidence", lines=1)
            html_output = gr.HTML(label="Detailed Results")

            single_outputs = [prediction_output, confidence_output, html_output]
            predict_btn.click(
                fn=predict_domain,
                inputs=domain_input,
                outputs=single_outputs,
            )
            # Pressing Enter in the one-line textbox runs the same
            # classification as clicking the button.
            domain_input.submit(
                fn=predict_domain,
                inputs=domain_input,
                outputs=single_outputs,
            )

            gr.Markdown("### Try these examples:")
            gr.Examples(
                examples=EXAMPLES,
                inputs=domain_input,
                outputs=single_outputs,
                fn=predict_domain,
                cache_examples=False,
            )

        # Tab 2: batch prediction (one domain per line)
        with gr.Tab("Batch Prediction"):
            gr.Markdown("### Classify multiple domains at once (one per line)")
            batch_input = gr.Textbox(
                label="Enter Domains (one per line)",
                placeholder="google.com\ngithub.com\nxjkd8f2h.com",
                lines=8,
            )
            batch_btn = gr.Button(
                "\U0001f50d Classify All Domains", variant="primary", size="lg"
            )
            batch_output = gr.Markdown(label="Results")
            batch_btn.click(
                fn=predict_batch,
                inputs=batch_input,
                outputs=batch_output,
            )

            gr.Markdown("### Try these examples:")
            gr.Examples(
                examples=BATCH_EXAMPLES,
                inputs=batch_input,
                outputs=batch_output,
                fn=predict_batch,
                cache_examples=False,
            )

        # Tab 3: background on DGAs and the model
        with gr.Tab("About"):
            gr.Markdown(_ABOUT_MD)

    gr.Markdown(_FOOTER_MD)
def main() -> None:
    """Start the Gradio server for the demo app."""
    demo.launch()


if __name__ == "__main__":
    main()
|