Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,15 +9,15 @@ def load_model():
|
|
| 9 |
"""Load the trained GPT model"""
|
| 10 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 11 |
GPT_CONFIG = {
|
| 12 |
-
vocab_size
|
| 13 |
-
n_heads
|
| 14 |
-
n_layers
|
| 15 |
-
head_size
|
| 16 |
-
n_embd
|
| 17 |
-
block_size
|
| 18 |
-
dropout
|
| 19 |
-
learning_rate
|
| 20 |
-
weight_decay
|
| 21 |
}
|
| 22 |
model = GPTLanguageModel(GPT_CONFIG)
|
| 23 |
model.load_state_dict(torch.load("model_weights.pth", map_location=device))
|
|
|
|
| 9 |
"""Load the trained GPT model"""
|
| 10 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 11 |
GPT_CONFIG = {
|
| 12 |
+
"vocab_size" : 50257,
|
| 13 |
+
"n_heads" : 8,
|
| 14 |
+
"n_layers" : 6,
|
| 15 |
+
"head_size" : 64,
|
| 16 |
+
"n_embd" : 512,
|
| 17 |
+
"block_size" : 128,
|
| 18 |
+
"dropout" : 0.1,
|
| 19 |
+
"learning_rate" : 3e-4,
|
| 20 |
+
"weight_decay" : 0.1,
|
| 21 |
}
|
| 22 |
model = GPTLanguageModel(GPT_CONFIG)
|
| 23 |
model.load_state_dict(torch.load("model_weights.pth", map_location=device))
|