# dga-detector/model.py
"""DGA Detection Model using Transformer Encoder.
This model treats domain names as sequences of characters and uses a Transformer
encoder to learn patterns that distinguish DGA (algorithmically generated) domains
from legitimate ones.
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from charset import PAD, VOCAB_SIZE
NUM_CLASSES = 2
class DGAEncoder(nn.Module):
"""Transformer encoder for DGA (Domain Generation Algorithm) detection."""
def __init__(
self,
*,
vocab_size: int,
max_len: int = 64,
d_model: int = 256,
nhead: int = 8,
num_layers: int = 4,
dropout: float = 0.1,
ffn_mult: int = 4,
) -> None:
super().__init__()
self.tok = nn.Embedding(vocab_size, d_model, padding_idx=PAD)
self.pos = nn.Embedding(max_len, d_model)
self.register_buffer(
"position_ids",
torch.arange(max_len).unsqueeze(0),
persistent=False,
)
enc_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=ffn_mult * d_model,
dropout=dropout,
batch_first=True,
norm_first=True,
)
self.enc = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
self.norm = nn.LayerNorm(d_model)
self.clf = nn.Linear(d_model, NUM_CLASSES)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through the encoder."""
b, L = x.shape
pos = self.position_ids[:, :L].expand(b, L)
h = self.tok(x) + self.pos(pos)
h = self.enc(h)
cls = self.norm(h[:, 0])
return self.clf(cls)
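# A minimal usage sketch for the bare encoder (illustrative only; it assumes the
# character vocabulary defined in `charset` is the one used at training time):
#
#     enc = DGAEncoder(vocab_size=VOCAB_SIZE)
#     logits = enc(torch.randint(0, VOCAB_SIZE, (4, 32)))  # -> shape (4, 2)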


class DGAEncoderConfig(PretrainedConfig):
    """Configuration for DGAEncoder compatible with HuggingFace Transformers."""

    model_type = "dga_encoder"

    def __init__(
        self,
        vocab_size: int = VOCAB_SIZE,
        max_len: int = 64,
        d_model: int = 256,
        nhead: int = 8,
        num_layers: int = 4,
        dropout: float = 0.1,
        ffn_mult: int = 4,
        num_labels: int = 2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.d_model = d_model
        self.nhead = nhead
        self.num_layers = num_layers
        self.dropout = dropout
        self.ffn_mult = ffn_mult
        self.num_labels = num_labels
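# Config round-trip sketch (illustrative; the directory name is hypothetical):
#
#     cfg = DGAEncoderConfig(d_model=128)
#     cfg.save_pretrained("./dga-encoder")                   # writes config.json
#     cfg = DGAEncoderConfig.from_pretrained("./dga-encoder")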


class DGAEncoderForSequenceClassification(PreTrainedModel):
    """HuggingFace-compatible wrapper around DGAEncoder."""

    config_class = DGAEncoderConfig

    def __init__(self, config: DGAEncoderConfig):
        super().__init__(config)
        self.config = config
        self.encoder = DGAEncoder(
            vocab_size=config.vocab_size,
            max_len=config.max_len,
            d_model=config.d_model,
            nhead=config.nhead,
            num_layers=config.num_layers,
            dropout=config.dropout,
            ffn_mult=config.ffn_mult,
        )
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """Forward pass compatible with the HF Trainer."""
        return_dict = (
            return_dict
            if return_dict is not None
            else self.config.use_return_dict
        )

        # attention_mask is accepted for Trainer/pipeline compatibility but is
        # not forwarded: DGAEncoder attends over padded positions as well.
        logits = self.encoder(input_ids)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.config.num_labels), labels.view(-1)
            )

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )
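

# ---------------------------------------------------------------------------
# Smoke test (a sketch, not part of the deployed Space): builds the model from
# a default config and pushes a dummy batch through it. The input ids below
# are random and only exercise tensor shapes; real inputs would come from
# whatever character-level encoding `charset` defines.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = DGAEncoderConfig()
    model = DGAEncoderForSequenceClassification(config)
    model.eval()

    # Two fake "domains" of 16 characters each, drawn from the full vocabulary.
    dummy_ids = torch.randint(0, config.vocab_size, (2, 16))
    with torch.no_grad():
        out = model(input_ids=dummy_ids)
    print(out.logits.shape)  # expected: torch.Size([2, 2])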