File size: 11,247 Bytes
c11fb0c
 
 
 
 
 
 
 
 
5d1d43b
 
 
c11fb0c
 
 
5d1d43b
c11fb0c
5d1d43b
c11fb0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
875a2e3
 
 
c11fb0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
"""Gradio demo for DGA domain classifier on HuggingFace Spaces.

This demo loads the model from HuggingFace Hub and provides an interactive
interface for classifying domains as legitimate or DGA-generated.
"""

import html

import torch
import gradio as gr

# Import custom model and encoding
from model import DGAEncoderForSequenceClassification
from charset import encode_domain


# Load model from HuggingFace Hub
# NOTE: everything below runs at import time, so importing this module
# triggers the model load and prints progress to stdout.
MODEL_NAME = "ccss17/dga-transformer-encoder"
print(f"Loading model from {MODEL_NAME}...")
model = DGAEncoderForSequenceClassification.from_pretrained(MODEL_NAME)
# Inference only: this file never trains, so switch the model to eval mode.
model.eval()

# Prefer CUDA when it is available, otherwise fall back to CPU. The same
# `device` is reused by the prediction functions when building input tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"Model loaded on {device}")


def predict_domain(domain: str):
    """Classify a domain as legitimate or DGA-generated.

    The input is stripped and lower-cased, encoded with ``encode_domain``
    (``max_len=64``), and run through the module-level ``model`` on
    ``device``. Class index 0 is reported as legitimate and index 1 as DGA.

    Args:
        domain: Domain name to classify (e.g., "google.com", "xjkd8f2h.com")

    Returns:
        tuple: (prediction_label, confidence_score, html_output), all
        strings. For empty or whitespace-only input the tuple is
        ("⚠️ Invalid Input", "", "Please enter a domain name.") and the
        model is not called.
    """
    if not domain or not domain.strip():
        return "⚠️ Invalid Input", "", "Please enter a domain name."

    domain = domain.strip().lower()

    # Encode domain to token IDs (batch of one)
    input_ids = torch.tensor(
        [encode_domain(domain, max_len=64)], device=device
    )

    # Get model prediction
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
        logits = outputs.logits
        probs = torch.softmax(logits, dim=-1)
        pred_class = torch.argmax(probs, dim=-1).item()
        confidence = probs[0, pred_class].item()

    # Format output
    label_names = ["✅ Legitimate", "🚨 DGA (Malicious)"]
    prediction = label_names[pred_class]
    confidence_pct = f"{confidence * 100:.2f}%"

    legit_prob = probs[0, 0].item()
    dga_prob = probs[0, 1].item()

    # Green panel for legitimate, red panel for DGA.
    is_legit = pred_class == 0
    panel_bg = "#d4edda" if is_legit else "#f8d7da"
    heading_color = "#155724" if is_legit else "#721c24"

    # The domain is untrusted user input that is echoed into raw HTML, so it
    # must be escaped to keep markup/script from being injected into the page.
    safe_domain = html.escape(domain)

    # Bar widths are formatted as fixed-point: an unformatted float can come
    # out in scientific notation (e.g. 1e-05) for very small probabilities.
    html_output = f"""
    <div style="padding: 20px; border-radius: 10px; background: {panel_bg};">
        <h2 style="margin-top: 0; color: {heading_color};">
            {prediction}
        </h2>
        <p style="font-size: 18px; margin: 10px 0;">
            <strong>Domain:</strong> <code>{safe_domain}</code>
        </p>
        <p style="font-size: 18px; margin: 10px 0;">
            <strong>Confidence:</strong> {confidence_pct}
        </p>
        <hr style="margin: 15px 0; border: none; border-top: 1px solid #ccc;">
        <h3>Probability Breakdown:</h3>
        <div style="margin: 10px 0;">
            <div style="display: flex; justify-content: space-between; margin-bottom: 5px;">
                <span>✅ Legitimate:</span>
                <strong>{legit_prob * 100:.2f}%</strong>
            </div>
            <div style="background: #e9ecef; border-radius: 5px; overflow: hidden; height: 20px;">
                <div style="background: #28a745; height: 100%; width: {legit_prob * 100:.2f}%;"></div>
            </div>
        </div>
        <div style="margin: 10px 0;">
            <div style="display: flex; justify-content: space-between; margin-bottom: 5px;">
                <span>🚨 DGA (Malicious):</span>
                <strong>{dga_prob * 100:.2f}%</strong>
            </div>
            <div style="background: #e9ecef; border-radius: 5px; overflow: hidden; height: 20px;">
                <div style="background: #dc3545; height: 100%; width: {dga_prob * 100:.2f}%;"></div>
            </div>
        </div>
    </div>
    """

    return prediction, confidence_pct, html_output


def predict_batch(domains_text: str):
    """Classify multiple domains at once.

    Each non-blank line is stripped, lower-cased, encoded with
    ``encode_domain`` (``max_len=64``) and classified one at a time with the
    module-level ``model``. Class index 0 is reported as legitimate, any
    other index as DGA.

    Args:
        domains_text: Newline-separated list of domains

    Returns:
        str: Formatted results table in Markdown (one row per domain), or a
        prompt message when the input contains no domains.
    """
    # Blank lines are dropped, so empty/whitespace-only/None input yields an
    # empty list and is handled by the single guard below.
    domains = [
        line.strip().lower()
        for line in (domains_text or "").split("\n")
        if line.strip()
    ]

    if not domains:
        return "Please enter one or more domain names (one per line)."

    results = [
        "| Domain | Prediction | Confidence |",
        "|--------|------------|------------|",
    ]

    # Gradients are never needed here; enter no_grad once for the whole batch.
    with torch.no_grad():
        for domain in domains:
            # Encode domain (batch of one)
            input_ids = torch.tensor(
                [encode_domain(domain, max_len=64)], device=device
            )

            # Predict
            logits = model(input_ids=input_ids).logits
            probs = torch.softmax(logits, dim=-1)
            pred_class = torch.argmax(probs, dim=-1).item()
            confidence = probs[0, pred_class].item()

            label = "Legitimate ✅" if pred_class == 0 else "DGA 🚨"
            confidence_pct = f"{confidence * 100:.1f}%"

            # The domain is echoed inside a code span in a table cell. A raw
            # "|" would start a new column and a backtick would close the
            # code span, so neutralise both. Neither character is valid in a
            # hostname, so well-formed input is displayed unchanged.
            cell = domain.replace("`", "'").replace("|", "\\|")

            results.append(f"| `{cell}` | {label} | {confidence_pct} |")

    return "\n".join(results)


# Sample inputs surfaced in the UI.
# EXAMPLES holds one single-element row per domain; BATCH_EXAMPLES holds
# newline-joined groups of domains, one string per example.
EXAMPLES = [
    [name]
    for name in (
        "google.com",
        "github.com",
        "stackoverflow.com",
        "xjkd8f2h.com",
        "qwfp93nx.net",
        "h4fk29fd.org",
        "facebook.com",
        "fjdkslajf.com",
    )
]

BATCH_EXAMPLES = [
    "\n".join(group)
    for group in (
        ("google.com", "github.com", "stackoverflow.com"),
        ("xjkd8f2h.com", "qwfp93nx.net", "h4fk29fd.org"),
        ("amazon.com", "facebook.com", "twitter.com", "microsoft.com"),
    )
]


# Create Gradio interface
# Layout: an intro banner, three tabs (single prediction, batch prediction,
# about), and a disclaimer footer. The resulting Blocks app is bound to
# `demo`.
with gr.Blocks(title="DGA Domain Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🔍 DGA Domain Classifier
    
    **Detect malicious domains generated by Domain Generation Algorithms (DGAs)**
    
    This model uses a transformer-based neural network to classify domains as either:
    - ✅ **Legitimate**: Normal, human-registered domains
    - 🚨 **DGA (Malicious)**: Algorithmically-generated domains used by malware for C2 communication
    
    ---
    
    ### Model Details
    - **Architecture**: Custom Transformer Encoder (4 layers, 256 dim)
    - **Parameters**: 3.2M
    - **Accuracy**: 96.78% F1 score on test set
    - **Inference Speed**: <1ms per domain
    
    ---
    """)

    with gr.Tabs():
        # Tab 1: Single domain prediction
        with gr.Tab("Single Domain"):
            gr.Markdown("### Test a single domain name")

            with gr.Row():
                with gr.Column(scale=2):
                    domain_input = gr.Textbox(
                        label="Enter Domain Name",
                        placeholder="e.g., google.com, xjkd8f2h.com",
                        lines=1,
                    )
                    predict_btn = gr.Button(
                        "🔍 Classify Domain", variant="primary", size="lg"
                    )

                with gr.Column(scale=1):
                    prediction_output = gr.Textbox(label="Prediction", lines=1)
                    confidence_output = gr.Textbox(label="Confidence", lines=1)

            html_output = gr.HTML(label="Detailed Results")

            predict_btn.click(
                fn=predict_domain,
                inputs=domain_input,
                outputs=[prediction_output, confidence_output, html_output],
            )
            # Pressing Enter in the single-line textbox runs the same handler
            # as the button, so keyboard-only use works as expected.
            domain_input.submit(
                fn=predict_domain,
                inputs=domain_input,
                outputs=[prediction_output, confidence_output, html_output],
            )

            gr.Markdown("### Try these examples:")
            # cache_examples=False: examples are computed on click rather
            # than precomputed when the app is built.
            gr.Examples(
                examples=EXAMPLES,
                inputs=domain_input,
                outputs=[prediction_output, confidence_output, html_output],
                fn=predict_domain,
                cache_examples=False,
            )

        # Tab 2: Batch prediction
        with gr.Tab("Batch Prediction"):
            gr.Markdown("### Classify multiple domains at once (one per line)")

            batch_input = gr.Textbox(
                label="Enter Domains (one per line)",
                placeholder="google.com\ngithub.com\nxjkd8f2h.com",
                lines=8,
            )
            batch_btn = gr.Button(
                "🔍 Classify All Domains", variant="primary", size="lg"
            )
            batch_output = gr.Markdown(label="Results")

            batch_btn.click(
                fn=predict_batch,
                inputs=batch_input,
                outputs=batch_output,
            )

            gr.Markdown("### Try these examples:")
            gr.Examples(
                examples=BATCH_EXAMPLES,
                inputs=batch_input,
                outputs=batch_output,
                fn=predict_batch,
                cache_examples=False,
            )

        # Tab 3: About
        with gr.Tab("About"):
            gr.Markdown("""
            ## 📚 What are DGAs?
            
            **Domain Generation Algorithms (DGAs)** are techniques used by malware to generate 
            large numbers of pseudo-random domain names for C2 (command-and-control) communication.
            
            ### Why DGAs are dangerous:
            - **Evasion**: Traditional blacklists can't keep up with thousands of generated domains
            - **Resilience**: Even if some domains are blocked, malware can try others
            - **Stealth**: DGA domains look random, making detection challenging
            
            ### How this model works:
            
            1. **Character-level tokenization**: Breaks domain into individual characters
            2. **Transformer encoder**: Learns patterns in character sequences
            3. **Self-attention**: Detects unusual character combinations (e.g., `xqz`, `fgh`)
            4. **Classification**: Predicts if domain is legitimate or DGA-generated
            
            ### Key Features:
            - **High accuracy**: 96.78% F1 score on test set
            - **Fast inference**: <1ms per domain (GPU) or ~10ms (CPU)
            - **Lightweight**: Only 3.2M parameters
            - **Production-ready**: Trained on real-world malware domains
            
            ### Examples:
            
            **Legitimate domains** (structured, pronounceable):
            - `google.com`, `github.com`, `stackoverflow.com`
            - `api-docs.company.com`, `cdn-assets.example.org`
            
            **DGA domains** (random, unpronounceable):
            - `xjkd8f2h.com`, `qwfp93nx.net`, `h4fk29fd.org`
            - `kdjf92jd.info`, `zmxbv73k.biz`
            
            ---
            
            ### Technical Details:
            - **Model**: Custom Transformer Encoder
            - **Training data**: ExtraHop DGA dataset
            - **Framework**: PyTorch + HuggingFace Transformers
            - **Experiment tracking**: Weights & Biases
            
            ### Links:
            - [Model Card](https://huggingface.co/ccss17/dga-transformer-encoder)
            - [GitHub Repository](https://github.com/ccss17/DGA-Transformer-Encoder)
            - [ExtraHop DGA Dataset](https://github.com/extrahop/dga-training-data)
            
            ---
            
            **Built with ❤️ using PyTorch, HuggingFace, and Gradio**
            """)

    gr.Markdown("""
    ---
    <div style="text-align: center; color: #666; font-size: 14px;">
        <p>⚠️ <strong>Disclaimer</strong>: This model is for educational and research purposes. 
        Always use multiple detection methods in production security systems.</p>
        <p>Model accuracy: 96.78% | False positive rate: ~3%</p>
    </div>
    """)


# Launch the Gradio app only when this file is executed as a script;
# importing the module builds `demo` without starting a server.
if __name__ == "__main__":
    demo.launch()