Update app.py
app.py CHANGED
@@ -100,15 +100,27 @@ class GPT(nn.Module):
 
         return logits, loss
 
-#
+# Updated load_model function
 def load_model(model_path):
     config = GPTConfig()
     model = GPT(config)
-
+
+    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
+
+    print("Checkpoint keys:", checkpoint.keys())  # Debug print
+
+    if 'model_state_dict' in checkpoint:
+        # If the checkpoint contains a 'model_state_dict' key, use that
+        model.load_state_dict(checkpoint['model_state_dict'])
+    else:
+        # Otherwise, try to load the state dict directly
+        model.load_state_dict(checkpoint)
+
     model.eval()
     return model
 
-
+# Load the trained model
+model = load_model('gpt_5000.pt')  # Replace with the actual path to your .pt file
 enc = tiktoken.get_encoding('gpt2')
 
 def generate_text(prompt, max_length=100, temperature=0.7):
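
The new branch on 'model_state_dict' exists because PyTorch checkpoints are commonly saved in one of two layouts. The training script that produced gpt_5000.pt is not part of this diff, so the following is only a sketch of how each layout is typically written (model, optimizer, and step are assumed names):

import torch

# Hypothetical save code, not from this repo.
# Layout 1: a dict wrapping the weights -- handled by the 'model_state_dict' branch.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),  # often saved alongside the weights
    'step': step,
}, 'gpt_5000.pt')

# Layout 2: the raw state dict itself -- handled by the else branch.
torch.save(model.state_dict(), 'gpt_5000.pt')

Checking for the wrapper key before calling load_state_dict lets load_model accept either file without the caller knowing which format was used.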
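
The body of generate_text lies outside this hunk, so only its signature is visible. As a rough, non-authoritative sketch of what a function with this signature usually does here (encode the prompt with the tiktoken encoder, sample tokens autoregressively with temperature scaling, decode the result), assuming the model's forward pass returns (logits, loss) as shown above and that targets are optional:

import torch
import torch.nn.functional as F

# Illustrative sketch only; the actual implementation is not shown in this diff.
def generate_text(prompt, max_length=100, temperature=0.7):
    tokens = torch.tensor(enc.encode(prompt), dtype=torch.long).unsqueeze(0)
    with torch.no_grad():
        for _ in range(max_length):
            logits, _ = model(tokens)                 # assumes forward returns (logits, loss)
            logits = logits[:, -1, :] / temperature   # scale logits for the last position
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat([tokens, next_token], dim=1)
    return enc.decode(tokens[0].tolist())

A call such as generate_text("Once upon a time", max_length=50) would then return the prompt extended with 50 sampled tokens; lower temperature values make the sampling more deterministic.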