pro-grammer committed
Commit 66e9630 · verified · 1 Parent(s): 8e09dca

Update model.py

Files changed (1):
  1. model.py  +18 -10
model.py CHANGED
@@ -114,16 +114,19 @@ class GPTLanguageModel(nn.Module):
 
     @torch.no_grad()
     def generate(self, idx, max_new_tokens, temperature=0.8, top_k=50, eos_token=None, max_consecutive_exclamations=2):
-        generated_tokens = []
         consecutive_exclamations = 0
-
         for _ in range(max_new_tokens):
-            # Forward pass
-            logits, _ = self(idx)
+            # Crop idx to the last block_size tokens if it exceeds block_size
+            idx_cond = idx[:, -self.block_size:]
+
+            # Get the predictions
+            logits, _ = self(idx_cond)
+
+            # Focus only on the last time step
             logits = logits[:, -1, :] / temperature
 
-            # Top-k sampling
-            top_k_logits, top_k_indices = torch.topk(logits, top_k)
+            # Apply top-k sampling
+            top_k_logits, top_k_indices = torch.topk(logits, min(top_k, logits.size(-1)))
             probs = F.softmax(top_k_logits, dim=-1)
             idx_next = top_k_indices[0, torch.multinomial(probs[0], num_samples=1)]
 
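The hunk above fixes two real bugs: the forward pass previously ran on the full, unbounded idx (overflowing the model's positional embeddings once the sequence passed block_size), and torch.topk would raise if top_k exceeded the vocabulary size. A minimal standalone sketch of the updated sampling step, assuming only what this file shows (a model whose forward returns (logits, loss) and a block_size attribute); the function name is illustrative:

import torch
import torch.nn.functional as F

@torch.no_grad()
def sample_next_token(model, idx, block_size, temperature=0.8, top_k=50):
    # Crop the context to the last block_size tokens, as the diff does
    idx_cond = idx[:, -block_size:]

    # Forward pass; this repo's model returns (logits, loss)
    logits, _ = model(idx_cond)

    # Keep only the last time step and apply temperature scaling
    logits = logits[:, -1, :] / temperature

    # Never request more candidates than the vocabulary holds
    k = min(top_k, logits.size(-1))
    top_k_logits, top_k_indices = torch.topk(logits, k)

    # Sample from the renormalized top-k distribution
    probs = F.softmax(top_k_logits, dim=-1)
    choice = torch.multinomial(probs[0], num_samples=1)
    return top_k_indices[0, choice]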
 
@@ -135,13 +138,13 @@ class GPTLanguageModel(nn.Module):
             else:
                 consecutive_exclamations = 0
 
-            generated_tokens.append(idx_next.item())
+            # Append sampled index to the running sequence
             idx = torch.cat((idx, idx_next.unsqueeze(0).unsqueeze(1)), dim=1)
 
             # Stop if EOS token is generated
             if eos_token is not None and idx_next.item() == eos_token:
                 break
-
+
         return idx
 
 # Set up the device
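The else: branch at the top of this hunk is the tail of the consecutive-exclamation guard; its body sits between the two hunks and is not shown in this commit. Purely as a hedged reconstruction, a counter of that shape can be factored out as below. The stop-versus-resample policy and the token id for "!" are assumptions, not taken from this diff:

def update_exclamation_guard(token_id, run_length, exclamation_id, max_run=2):
    """Track consecutive occurrences of a banned token.

    Returns (should_stop, new_run_length). exclamation_id would come from
    the tokenizer, e.g. enc.encode("!")[0] (an assumption here).
    """
    if token_id == exclamation_id:
        run_length += 1
    else:
        run_length = 0
    return run_length >= max_run, run_length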
@@ -161,8 +164,8 @@ weight_decay = 0.1
 # Create an instance of the model
 model = GPTLanguageModel(vocab_size, n_embd, block_size, n_layers, n_heads).to(device)
 
-# Load the trained weights
-model.load_state_dict(torch.load("model_weights.pth", map_location=device))
+# Load the model (with weights_only=True for security)
+model.load_state_dict(torch.load("model_weights.pth", map_location=device, weights_only=True))
 
 # Set the model to evaluation mode
 model.eval()
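weights_only=True restricts torch.load to tensors and other allow-listed types instead of running arbitrary pickle code; the flag was added in PyTorch 1.13 and later became the default (in 2.6). A version-tolerant load might look like the sketch below (the helper name is illustrative):

import inspect
import torch

def load_state_safely(model, path, device):
    # Use weights_only=True when this PyTorch build supports it (>= 1.13)
    if "weights_only" in inspect.signature(torch.load).parameters:
        state = torch.load(path, map_location=device, weights_only=True)
    else:
        # Older PyTorch: falls back to full (unsafe) unpickling
        state = torch.load(path, map_location=device)
    model.load_state_dict(state)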
@@ -178,6 +181,11 @@ max_new_tokens = 300
 temperature = 0.6 # Slightly lower temperature
 top_k = 40 # Adjust as needed
 # Generate text
+# Load the model (with weights_only=True for security)
+model.load_state_dict(torch.load("model_weights.pth", map_location=device, weights_only=True))
+
+# Generate text
+context = torch.tensor([enc.encode("Once upon a time there was a knight called Bob and he rode into his greatest battle yet")], dtype=torch.long, device=device)
 generated_text_idx = model.generate(context, max_new_tokens, temperature=temperature, top_k=top_k, eos_token=eos_token, max_consecutive_exclamations=2)
 generated_text = enc.decode(generated_text_idx[0].tolist())
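Note that the added lines load the state dict a second time, right before generation; only the context line is strictly new here. For reference, an end-to-end usage sketch of the updated entry point, reusing model, device, and eos_token from the script above and assuming enc is a tiktoken-style encoding (the "gpt2" choice is an assumption; the commit does not show how enc is built):

import tiktoken
import torch

enc = tiktoken.get_encoding("gpt2")  # assumption: tokenizer setup is not part of this diff

prompt = "Once upon a time there was a knight called Bob and he rode into his greatest battle yet"
context = torch.tensor([enc.encode(prompt)], dtype=torch.long, device=device)

out = model.generate(context, max_new_tokens=300, temperature=0.6, top_k=40,
                     eos_token=eos_token, max_consecutive_exclamations=2)
print(enc.decode(out[0].tolist()))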
 
 
191