Update README.md
Browse files
README.md
CHANGED
|
@@ -50,7 +50,7 @@ This model specializes in cybersecurity contexts. Predictions for unrelated cont
|
|
| 50 |
|
| 51 |
Always verify predictions with cybersecurity analysts before using them in critical decision-making scenarios.
|
| 52 |
|
| 53 |
-
## How to Get Started with the Model
|
| 54 |
|
| 55 |
```python
|
| 56 |
import torch
|
|
@@ -104,6 +104,46 @@ print(f"Predicted GroupID: {predicted_class}")
|
|
| 104 |
```
|
| 105 |
Predicted GroupID: G0001
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
## Training Details
|
| 109 |
|
|
|
|
| 50 |
|
| 51 |
Always verify predictions with cybersecurity analysts before using them in critical decision-making scenarios.
|
| 52 |
|
| 53 |
+
## How to Get Started with the Model (Classification)
|
| 54 |
|
| 55 |
```python
|
| 56 |
import torch
|
|
|
|
| 104 |
```
|
| 105 |
Predicted GroupID: G0001
|
| 106 |
|
| 107 |
+
## How to Get Started with the Model (Embeddings)
|
| 108 |
+
|
| 109 |
+
```python
|
| 110 |
+
import torch
|
| 111 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 112 |
+
|
| 113 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 114 |
+
|
| 115 |
+
# Load your fine-tuned classification model
|
| 116 |
+
model_name = "selfconstruct3d/AttackGroup-MPNET"
|
| 117 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 118 |
+
classifier_model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
|
| 119 |
+
|
| 120 |
+
def get_embedding(sentence):
|
| 121 |
+
classifier_model.eval()
|
| 122 |
+
|
| 123 |
+
encoding = tokenizer(
|
| 124 |
+
sentence,
|
| 125 |
+
truncation=True,
|
| 126 |
+
padding="max_length",
|
| 127 |
+
max_length=128,
|
| 128 |
+
return_tensors="pt"
|
| 129 |
+
)
|
| 130 |
+
input_ids = encoding["input_ids"].to(device)
|
| 131 |
+
attention_mask = encoding["attention_mask"].to(device)
|
| 132 |
+
|
| 133 |
+
with torch.no_grad():
|
| 134 |
+
outputs = classifier_model.mpnet(input_ids=input_ids, attention_mask=attention_mask)
|
| 135 |
+
cls_embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy().flatten()
|
| 136 |
+
|
| 137 |
+
return cls_embedding
|
| 138 |
+
|
| 139 |
+
# Example usage:
|
| 140 |
+
sentence = "APT38 has used phishing emails with malicious links to distribute malware."
|
| 141 |
+
embedding = get_embedding(sentence)
|
| 142 |
+
print("Embedding shape:", embedding.shape)
|
| 143 |
+
print("Embedding values:", embedding)
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
|
| 147 |
|
| 148 |
## Training Details
|
| 149 |
|