Update model.py
model.py
CHANGED
@@ -114,16 +114,19 @@ class GPTLanguageModel(nn.Module):
 
     @torch.no_grad()
     def generate(self, idx, max_new_tokens, temperature=0.8, top_k=50, eos_token=None, max_consecutive_exclamations=2):
-        generated_tokens = []
         consecutive_exclamations = 0
-
         for _ in range(max_new_tokens):
-            #
-            logits, _ = self(idx)
+            # Crop idx to the last block_size tokens if it exceeds block_size
+            idx_cond = idx[:, -self.block_size:]
+
+            # Get the predictions
+            logits, _ = self(idx_cond)
+
+            # Focus only on the last time step
             logits = logits[:, -1, :] / temperature
 
-            #
-            top_k_logits, top_k_indices = torch.topk(logits, top_k)
+            # Apply top-k sampling
+            top_k_logits, top_k_indices = torch.topk(logits, min(top_k, logits.size(-1)))
             probs = F.softmax(top_k_logits, dim=-1)
             idx_next = top_k_indices[0, torch.multinomial(probs[0], num_samples=1)]
 
@@ -135,13 +138,13 @@ class GPTLanguageModel(nn.Module):
             else:
                 consecutive_exclamations = 0
 
-
+            # Append sampled index to the running sequence
             idx = torch.cat((idx, idx_next.unsqueeze(0).unsqueeze(1)), dim=1)
 
             # Stop if EOS token is generated
             if eos_token is not None and idx_next.item() == eos_token:
                 break
-
+
         return idx
 
 # Set up the device
@@ -161,8 +164,8 @@ weight_decay = 0.1
 # Create an instance of the model
 model = GPTLanguageModel(vocab_size, n_embd, block_size, n_layers, n_heads).to(device)
 
-# Load the
-model.load_state_dict(torch.load("model_weights.pth", map_location=device))
+# Load the model (with weights_only=True for security)
+model.load_state_dict(torch.load("model_weights.pth", map_location=device, weights_only=True))
 
 # Set the model to evaluation mode
 model.eval()
@@ -178,6 +181,11 @@ max_new_tokens = 300
 temperature = 0.6 # Slightly lower temperature
 top_k = 40 # Adjust as needed
 # Generate text
+# Load the model (with weights_only=True for security)
+model.load_state_dict(torch.load("model_weights.pth", map_location=device, weights_only=True))
+
+# Generate text
+context = torch.tensor([enc.encode("Once upon a time there was a knight called Bob and he rode into his greatest battle yet")], dtype=torch.long, device=device)
 generated_text_idx = model.generate(context, max_new_tokens, temperature=temperature, top_k=top_k, eos_token=eos_token, max_consecutive_exclamations=2)
 generated_text = enc.decode(generated_text_idx[0].tolist())
 
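Note on the top-k change: torch.topk raises an error if k exceeds the size of the last dimension, so the min(top_k, logits.size(-1)) guard makes the call safe even when top_k is larger than the vocabulary. A minimal standalone sketch of the same sampling step (the toy vocabulary size is made up for illustration):

    import torch
    import torch.nn.functional as F

    def sample_top_k(logits, top_k=50, temperature=0.8):
        # logits: (vocab_size,) raw scores for the next token
        logits = logits / temperature
        # Guard: never ask topk for more entries than exist
        k = min(top_k, logits.size(-1))
        top_k_logits, top_k_indices = torch.topk(logits, k)
        # Renormalize over the k kept tokens, then sample one
        probs = F.softmax(top_k_logits, dim=-1)
        choice = torch.multinomial(probs, num_samples=1)
        return top_k_indices[choice]

    next_token = sample_top_k(torch.randn(10), top_k=40)  # k is capped at 10 here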
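Note on weights_only=True: it restricts torch.load (PyTorch 1.13 or newer) to deserializing plain tensors and primitive containers instead of arbitrary pickled objects, so a tampered checkpoint cannot execute code on load. A minimal sketch of the save/load round trip (the toy module and file name are illustrative):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)

    # Save only the state dict (tensors), not the whole pickled module
    torch.save(model.state_dict(), "model_weights.pth")

    # weights_only=True refuses to unpickle arbitrary Python objects
    state = torch.load("model_weights.pth", map_location="cpu", weights_only=True)
    model.load_state_dict(state)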
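Note on the cropping fix: in a GPT-style model the positional embedding table only covers block_size positions, so conditioning on more than block_size tokens would index past it. Slicing with idx[:, -self.block_size:] keeps the most recent window, and is a no-op while the sequence is still shorter than block_size. A toy illustration with made-up sizes:

    import torch

    block_size = 8
    idx = torch.arange(12).unsqueeze(0)    # pretend 12 tokens generated, shape (1, 12)

    idx_cond = idx[:, -block_size:]        # keep only the last 8 tokens
    print(idx_cond.shape)                  # torch.Size([1, 8])

    short = torch.arange(5).unsqueeze(0)   # shorter than block_size
    print(short[:, -block_size:].shape)    # torch.Size([1, 5]), unchanged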
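Note on max_consecutive_exclamations: the hunks above only show the else branch that resets the counter, so the matching if branch is not in this diff. The following is a hypothetical reconstruction of how such a guard could work, with a made-up exclamation_token id and token stream:

    # Hypothetical sketch, not the author's code: the if-branch is not shown
    # in the diff above. exclamation_token is assumed to be the id for "!".
    exclamation_token = 0
    max_consecutive_exclamations = 2
    consecutive_exclamations = 0

    for idx_next in [3, 0, 0, 7, 0, 0]:    # pretend sampled token ids
        if idx_next == exclamation_token:
            consecutive_exclamations += 1
            if consecutive_exclamations >= max_consecutive_exclamations:
                break    # stop runaway "!!!..." repetition
        else:
            consecutive_exclamations = 0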